feat: impl StructArray -- support searching embeddings in an embedding list with element-level filter expression (#45830)

issue: https://github.com/milvus-io/milvus/issues/42148

Since a STRUCT can only appear as the element type of an ARRAY field, a
vector field inside a STRUCT is effectively an array of vectors, i.e. an
embedding list.
Milvus already supports searching embedding lists with metrics whose
names start with the prefix MAX_SIM_.

This PR allows Milvus to search embeddings inside an embedding list
using the same metrics as normal embedding fields. Each embedding in the
list is treated as an independent vector and participates in ANN search.

Further, since a STRUCT may contain scalar fields that are closely related
to the embedding field, this PR introduces an element-level filter
expression to refine search results.
The grammar of the element-level filter is:

element_filter(structFieldName, $[subFieldName] == 3)

where $[subFieldName] refers to the value of subFieldName in each
element of the STRUCT array structFieldName.

It can be combined with existing filter expressions, for example:

"varcharField == 'aaa' && element_filter(struct_field, $[struct_int] ==
3)"

A full example:
```python
struct_schema = milvus_client.create_struct_field_schema()
struct_schema.add_field("struct_str", DataType.VARCHAR, max_length=65535)
struct_schema.add_field("struct_int", DataType.INT32)
struct_schema.add_field("struct_float_vec", DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM)

schema.add_field(
    "struct_field",
    datatype=DataType.ARRAY,
    element_type=DataType.STRUCT,
    struct_schema=struct_schema,
    max_capacity=1000,
)
...

filter = "varcharField == 'aaa' && element_filter(struct_field, $[struct_int] == 3 && $[struct_str] == 'abc')"
res = milvus_client.search(
    COLLECTION_NAME,
    data=query_embeddings,
    limit=10,
    anns_field="struct_field[struct_float_vec]",
    filter=filter,
    output_fields=["struct_field[struct_int]", "varcharField"],
)

```
TODO:
1. When an `element_filter` expression is used, a regular filter expression
must currently also be present; remove this restriction (see the workaround
sketch below).
2. Support `element_filter` expressions in `query` requests.
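
Until restriction 1 is lifted, one possible workaround is to pair
`element_filter` with an always-true regular filter. A minimal sketch,
assuming an INT64 primary key field named `id`:

"id >= 0 && element_filter(struct_field, $[struct_int] == 3)"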

---------

Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
Spade A 2025-12-15 12:01:15 +08:00 committed by GitHub
parent ca2e27f576
commit f6f716bcfd
72 changed files with 8180 additions and 3680 deletions

View File

@ -0,0 +1,292 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ArrayOffsets.h"
#include "segcore/SegmentInterface.h"
#include "log/Log.h"
#include "common/EasyAssert.h"
namespace milvus {
std::pair<int32_t, int32_t>
ArrayOffsetsSealed::ElementIDToRowID(int32_t elem_id) const {
assert(elem_id >= 0 && elem_id < GetTotalElementCount());
int32_t row_id = element_row_ids_[elem_id];
// Compute elem_idx: elem_idx = elem_id - start_of_this_row
int32_t elem_idx = elem_id - row_to_element_start_[row_id];
return {row_id, elem_idx};
}
std::pair<int32_t, int32_t>
ArrayOffsetsSealed::ElementIDRangeOfRow(int32_t row_id) const {
int32_t row_count = GetRowCount();
assert(row_id >= 0 && row_id <= row_count);
if (row_id == row_count) {
auto total = row_to_element_start_[row_count];
return {total, total};
}
return {row_to_element_start_[row_id], row_to_element_start_[row_id + 1]};
}
std::pair<TargetBitmap, TargetBitmap>
ArrayOffsetsSealed::RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const {
int64_t row_count = GetRowCount();
int64_t element_count = GetTotalElementCount();
TargetBitmap element_bitset(element_count);
TargetBitmap valid_element_bitset(element_count);
for (int64_t row_id = 0; row_id < row_count; ++row_id) {
int64_t start = row_to_element_start_[row_id];
int64_t end = row_to_element_start_[row_id + 1];
if (start < end) {
element_bitset.set(start, end - start, row_bitset[row_id]);
valid_element_bitset.set(
start, end - start, valid_row_bitset[row_id]);
}
}
return {std::move(element_bitset), std::move(valid_element_bitset)};
}
std::shared_ptr<ArrayOffsetsSealed>
ArrayOffsetsSealed::BuildFromSegment(const void* segment,
const FieldMeta& field_meta) {
auto seg = static_cast<const segcore::SegmentInternalInterface*>(segment);
int64_t row_count = seg->get_row_count();
if (row_count == 0) {
LOG_INFO(
"ArrayOffsetsSealed::BuildFromSegment: empty segment for struct "
"'{}'",
field_meta.get_name().get());
return std::make_shared<ArrayOffsetsSealed>(std::vector<int32_t>{},
std::vector<int32_t>{0});
}
FieldId field_id = field_meta.get_id();
auto data_type = field_meta.get_data_type();
std::vector<int32_t> element_row_ids;
// Size is row_count + 1, last element stores total_element_count
std::vector<int32_t> row_to_element_start(row_count + 1);
auto temp_op_ctx = std::make_unique<OpContext>();
auto op_ctx_ptr = temp_op_ctx.get();
int64_t num_chunks = seg->num_chunk(field_id);
int32_t current_row_id = 0;
if (data_type == DataType::VECTOR_ARRAY) {
for (int64_t chunk_id = 0; chunk_id < num_chunks; ++chunk_id) {
auto pin_wrapper = seg->chunk_view<VectorArrayView>(
op_ctx_ptr, field_id, chunk_id);
const auto& [vector_array_views, valid_flags] = pin_wrapper.get();
for (size_t i = 0; i < vector_array_views.size(); ++i) {
int32_t array_len = 0;
if (valid_flags.empty() || valid_flags[i]) {
array_len = vector_array_views[i].length();
}
// Record the start position for this row
row_to_element_start[current_row_id] = element_row_ids.size();
// Add row_id for each element (elem_idx computed on access)
for (int32_t j = 0; j < array_len; ++j) {
element_row_ids.emplace_back(current_row_id);
}
current_row_id++;
}
}
} else {
for (int64_t chunk_id = 0; chunk_id < num_chunks; ++chunk_id) {
auto pin_wrapper =
seg->chunk_view<ArrayView>(op_ctx_ptr, field_id, chunk_id);
const auto& [array_views, valid_flags] = pin_wrapper.get();
for (size_t i = 0; i < array_views.size(); ++i) {
int32_t array_len = 0;
if (valid_flags.empty() || valid_flags[i]) {
array_len = array_views[i].length();
}
// Record the start position for this row
row_to_element_start[current_row_id] = element_row_ids.size();
// Add row_id for each element (elem_idx computed on access)
for (int32_t j = 0; j < array_len; ++j) {
element_row_ids.emplace_back(current_row_id);
}
current_row_id++;
}
}
}
// Store total element count as the last entry
row_to_element_start[row_count] = element_row_ids.size();
AssertInfo(current_row_id == row_count,
"Row count mismatch: expected {}, got {}",
row_count,
current_row_id);
int64_t total_elements = element_row_ids.size();
LOG_INFO(
"ArrayOffsetsSealed::BuildFromSegment: struct_name='{}', "
"field_id={}, row_count={}, total_elements={}",
field_meta.get_name().get(),
field_meta.get_id().get(),
row_count,
total_elements);
auto result = std::make_shared<ArrayOffsetsSealed>(
std::move(element_row_ids), std::move(row_to_element_start));
result->resource_size_ = 4 * (row_count + 1) + 4 * total_elements;
cachinglayer::Manager::GetInstance().ChargeLoadedResource(
cachinglayer::ResourceUsage{result->resource_size_, 0});
return result;
}
std::pair<int32_t, int32_t>
ArrayOffsetsGrowing::ElementIDToRowID(int32_t elem_id) const {
std::shared_lock lock(mutex_);
assert(elem_id >= 0 &&
elem_id < static_cast<int32_t>(element_row_ids_.size()));
int32_t row_id = element_row_ids_[elem_id];
// Compute elem_idx: elem_idx = elem_id - start_of_this_row
int32_t elem_idx = elem_id - row_to_element_start_[row_id];
return {row_id, elem_idx};
}
std::pair<int32_t, int32_t>
ArrayOffsetsGrowing::ElementIDRangeOfRow(int32_t row_id) const {
std::shared_lock lock(mutex_);
assert(row_id >= 0 && row_id <= committed_row_count_);
if (row_id == committed_row_count_) {
auto total = row_to_element_start_[committed_row_count_];
return {total, total};
}
return {row_to_element_start_[row_id], row_to_element_start_[row_id + 1]};
}
std::pair<TargetBitmap, TargetBitmap>
ArrayOffsetsGrowing::RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const {
std::shared_lock lock(mutex_);
int64_t element_count = element_row_ids_.size();
TargetBitmap element_bitset(element_count);
TargetBitmap valid_element_bitset(element_count);
// Direct access to element_row_ids_, no virtual function calls
for (size_t elem_id = 0; elem_id < element_row_ids_.size(); ++elem_id) {
auto row_id = element_row_ids_[elem_id];
element_bitset[elem_id] = row_bitset[row_id];
valid_element_bitset[elem_id] = valid_row_bitset[row_id];
}
return {std::move(element_bitset), std::move(valid_element_bitset)};
}
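// Record array lengths for rows [row_id_start, row_id_start + count).
// Rows may arrive out of order: only rows that extend the contiguous
// committed prefix are applied immediately; the rest are parked in
// pending_rows_ and drained once the gap is filled (see DrainPendingRows).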
void
ArrayOffsetsGrowing::Insert(int64_t row_id_start,
const int32_t* array_lengths,
int64_t count) {
std::unique_lock lock(mutex_);
row_to_element_start_.reserve(row_id_start + count + 1);
int32_t original_committed_count = committed_row_count_;
for (int64_t i = 0; i < count; ++i) {
int32_t row_id = row_id_start + i;
int32_t array_len = array_lengths[i];
if (row_id == committed_row_count_) {
// Record the start position for this row
// If sentinel exists at current position, overwrite it; otherwise push_back
if (row_to_element_start_.size() >
static_cast<size_t>(committed_row_count_)) {
row_to_element_start_[committed_row_count_] =
element_row_ids_.size();
} else {
row_to_element_start_.push_back(element_row_ids_.size());
}
// Add row_id for each element (elem_idx computed on access)
for (int32_t j = 0; j < array_len; ++j) {
element_row_ids_.emplace_back(row_id);
}
committed_row_count_++;
} else {
pending_rows_[row_id] = {row_id, array_len};
}
}
DrainPendingRows();
// Update the sentinel (total element count) only if we committed new rows
if (committed_row_count_ > original_committed_count) {
if (row_to_element_start_.size() ==
static_cast<size_t>(committed_row_count_)) {
row_to_element_start_.push_back(element_row_ids_.size());
} else {
row_to_element_start_[committed_row_count_] =
element_row_ids_.size();
}
}
}
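// Commit parked rows that now directly follow the committed prefix,
// advancing committed_row_count_ until the next gap is reached.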
void
ArrayOffsetsGrowing::DrainPendingRows() {
while (true) {
auto it = pending_rows_.find(committed_row_count_);
if (it == pending_rows_.end()) {
break;
}
const auto& pending = it->second;
// If sentinel exists at current position, overwrite it; otherwise push_back
if (row_to_element_start_.size() >
static_cast<size_t>(committed_row_count_)) {
row_to_element_start_[committed_row_count_] =
element_row_ids_.size();
} else {
row_to_element_start_.push_back(element_row_ids_.size());
}
for (int32_t j = 0; j < pending.array_len; ++j) {
element_row_ids_.emplace_back(static_cast<int32_t>(pending.row_id));
}
committed_row_count_++;
pending_rows_.erase(it);
}
}
} // namespace milvus

View File

@ -0,0 +1,164 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include <utility>
#include <map>
#include <memory>
#include <shared_mutex>
#include "cachinglayer/Manager.h"
#include "common/Types.h"
#include "common/EasyAssert.h"
#include "common/FieldMeta.h"
namespace milvus {
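// Maps between row (document) IDs and element IDs for an array-of-struct
// field, so element-level search and filter results can be translated back
// to rows. Element IDs are assigned contiguously across rows.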
class IArrayOffsets {
public:
virtual ~IArrayOffsets() = default;
virtual int64_t
GetRowCount() const = 0;
virtual int64_t
GetTotalElementCount() const = 0;
// Convert element ID to row ID
// returns pair of <row_id, element_index>
// element IDs are contiguous across rows
virtual std::pair<int32_t, int32_t>
ElementIDToRowID(int32_t elem_id) const = 0;
// Convert row ID to element ID range
// elements with IDs in [ret.first, ret.second) belong to row_id
virtual std::pair<int32_t, int32_t>
ElementIDRangeOfRow(int32_t row_id) const = 0;
// Convert row-level bitsets to element-level bitsets
virtual std::pair<TargetBitmap, TargetBitmap>
RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const = 0;
};
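// Immutable offset mapping for sealed segments; built once (see
// BuildFromSegment) and safe for concurrent reads.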
class ArrayOffsetsSealed : public IArrayOffsets {
friend class ArrayOffsetsTest;
public:
ArrayOffsetsSealed() : element_row_ids_(), row_to_element_start_({0}) {
}
ArrayOffsetsSealed(std::vector<int32_t> element_row_ids,
std::vector<int32_t> row_to_element_start)
: element_row_ids_(std::move(element_row_ids)),
row_to_element_start_(std::move(row_to_element_start)) {
AssertInfo(!row_to_element_start_.empty(),
"row_to_element_start must have at least one element");
}
~ArrayOffsetsSealed() {
cachinglayer::Manager::GetInstance().RefundLoadedResource(
{resource_size_, 0});
}
int64_t
GetRowCount() const override {
return static_cast<int64_t>(row_to_element_start_.size()) - 1;
}
int64_t
GetTotalElementCount() const override {
return element_row_ids_.size();
}
std::pair<int32_t, int32_t>
ElementIDToRowID(int32_t elem_id) const override;
std::pair<int32_t, int32_t>
ElementIDRangeOfRow(int32_t row_id) const override;
std::pair<TargetBitmap, TargetBitmap>
RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const override;
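// Scan every chunk of the given array field in a sealed segment and build
// the element<->row offset mapping; the resulting memory is charged to the
// caching layer.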
static std::shared_ptr<ArrayOffsetsSealed>
BuildFromSegment(const void* segment, const FieldMeta& field_meta);
private:
const std::vector<int32_t> element_row_ids_;
const std::vector<int32_t> row_to_element_start_;
int64_t resource_size_{0};
};
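// Mutable offset mapping for growing segments; Insert may be called as new
// rows arrive (possibly out of order) and reads are guarded by a shared mutex.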
class ArrayOffsetsGrowing : public IArrayOffsets {
public:
ArrayOffsetsGrowing() = default;
void
Insert(int64_t row_id_start, const int32_t* array_lengths, int64_t count);
int64_t
GetRowCount() const override {
std::shared_lock lock(mutex_);
return committed_row_count_;
}
int64_t
GetTotalElementCount() const override {
std::shared_lock lock(mutex_);
return element_row_ids_.size();
}
std::pair<int32_t, int32_t>
ElementIDToRowID(int32_t elem_id) const override;
std::pair<int32_t, int32_t>
ElementIDRangeOfRow(int32_t row_id) const override;
std::pair<TargetBitmap, TargetBitmap>
RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const override;
private:
struct PendingRow {
int64_t row_id;
int32_t array_len;
};
void
DrainPendingRows();
private:
std::vector<int32_t> element_row_ids_;
std::vector<int32_t> row_to_element_start_;
// Number of rows committed (contiguous from 0)
int32_t committed_row_count_ = 0;
// Pending rows waiting for earlier rows to complete
// Key: row_id, automatically sorted
std::map<int64_t, PendingRow> pending_rows_;
// Protects all member variables
mutable std::shared_mutex mutex_;
};
} // namespace milvus

View File

@ -0,0 +1,427 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <thread>
#include <vector>
#include "common/ArrayOffsets.h"
using namespace milvus;
class ArrayOffsetsTest : public ::testing::Test {
protected:
void
SetUp() override {
}
};
TEST_F(ArrayOffsetsTest, SealedBasic) {
// Create a simple ArrayOffsetsSealed manually
// row 0: 2 elements (elem 0, 1)
// row 1: 3 elements (elem 2, 3, 4)
// row 2: 1 element (elem 5)
ArrayOffsetsSealed offsets(
{0, 0, 1, 1, 1, 2}, // element_row_ids
{0, 2, 5, 6} // row_to_element_start (size = row_count + 1)
);
// Test GetRowCount
EXPECT_EQ(offsets.GetRowCount(), 3);
// Test GetTotalElementCount
EXPECT_EQ(offsets.GetTotalElementCount(), 6);
// Test ElementIDToRowID
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
EXPECT_EQ(row_id, 0);
EXPECT_EQ(elem_idx, 0);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(1);
EXPECT_EQ(row_id, 0);
EXPECT_EQ(elem_idx, 1);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(2);
EXPECT_EQ(row_id, 1);
EXPECT_EQ(elem_idx, 0);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(4);
EXPECT_EQ(row_id, 1);
EXPECT_EQ(elem_idx, 2);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(5);
EXPECT_EQ(row_id, 2);
EXPECT_EQ(elem_idx, 0);
}
// Test ElementIDRangeOfRow
{
auto [start, end] = offsets.ElementIDRangeOfRow(0);
EXPECT_EQ(start, 0);
EXPECT_EQ(end, 2);
}
{
auto [start, end] = offsets.ElementIDRangeOfRow(1);
EXPECT_EQ(start, 2);
EXPECT_EQ(end, 5);
}
{
auto [start, end] = offsets.ElementIDRangeOfRow(2);
EXPECT_EQ(start, 5);
EXPECT_EQ(end, 6);
}
// When row_id == row_count, return (total_elements, total_elements)
{
auto [start, end] = offsets.ElementIDRangeOfRow(3);
EXPECT_EQ(start, 6);
EXPECT_EQ(end, 6);
}
}
TEST_F(ArrayOffsetsTest, SealedRowBitsetToElementBitset) {
ArrayOffsetsSealed offsets({0, 0, 1, 1, 1, 2}, // element_row_ids
{0, 2, 5, 6} // row_to_element_start
);
// row_bitset: row 0 = true, row 1 = false, row 2 = true
TargetBitmap row_bitset(3);
row_bitset[0] = true;
row_bitset[1] = false;
row_bitset[2] = true;
TargetBitmap valid_row_bitset(3, true);
TargetBitmapView row_view(row_bitset.data(), row_bitset.size());
TargetBitmapView valid_view(valid_row_bitset.data(),
valid_row_bitset.size());
auto [elem_bitset, valid_elem_bitset] =
offsets.RowBitsetToElementBitset(row_view, valid_view);
EXPECT_EQ(elem_bitset.size(), 6);
// Elements of row 0 (elem 0, 1) should be true
EXPECT_TRUE(elem_bitset[0]);
EXPECT_TRUE(elem_bitset[1]);
// Elements of row 1 (elem 2, 3, 4) should be false
EXPECT_FALSE(elem_bitset[2]);
EXPECT_FALSE(elem_bitset[3]);
EXPECT_FALSE(elem_bitset[4]);
// Elements of row 2 (elem 5) should be true
EXPECT_TRUE(elem_bitset[5]);
}
TEST_F(ArrayOffsetsTest, SealedEmptyArrays) {
// Test with some rows having empty arrays
// row 1 and row 3 are empty
ArrayOffsetsSealed offsets({0, 0, 2, 2, 2}, // element_row_ids
{0, 2, 2, 5, 5} // row_to_element_start
);
EXPECT_EQ(offsets.GetRowCount(), 4);
EXPECT_EQ(offsets.GetTotalElementCount(), 5);
// Row 1 has no elements
{
auto [start, end] = offsets.ElementIDRangeOfRow(1);
EXPECT_EQ(start, 2);
EXPECT_EQ(end, 2); // empty range
}
// Row 3 has no elements
{
auto [start, end] = offsets.ElementIDRangeOfRow(3);
EXPECT_EQ(start, 5);
EXPECT_EQ(end, 5); // empty range
}
}
TEST_F(ArrayOffsetsTest, GrowingBasicInsert) {
ArrayOffsetsGrowing offsets;
// Insert rows in order
std::vector<int32_t> lens1 = {2}; // row 0: 2 elements
offsets.Insert(0, lens1.data(), 1);
std::vector<int32_t> lens2 = {3}; // row 1: 3 elements
offsets.Insert(1, lens2.data(), 1);
std::vector<int32_t> lens3 = {1}; // row 2: 1 element
offsets.Insert(2, lens3.data(), 1);
EXPECT_EQ(offsets.GetRowCount(), 3);
EXPECT_EQ(offsets.GetTotalElementCount(), 6);
// Test ElementIDToRowID
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
EXPECT_EQ(row_id, 0);
EXPECT_EQ(elem_idx, 0);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(2);
EXPECT_EQ(row_id, 1);
EXPECT_EQ(elem_idx, 0);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(5);
EXPECT_EQ(row_id, 2);
EXPECT_EQ(elem_idx, 0);
}
// Test ElementIDRangeOfRow
{
auto [start, end] = offsets.ElementIDRangeOfRow(0);
EXPECT_EQ(start, 0);
EXPECT_EQ(end, 2);
}
{
auto [start, end] = offsets.ElementIDRangeOfRow(1);
EXPECT_EQ(start, 2);
EXPECT_EQ(end, 5);
}
// When row_id == row_count, return (total_elements, total_elements)
{
auto [start, end] = offsets.ElementIDRangeOfRow(3);
EXPECT_EQ(start, 6);
EXPECT_EQ(end, 6);
}
}
TEST_F(ArrayOffsetsTest, GrowingBatchInsert) {
ArrayOffsetsGrowing offsets;
// Insert multiple rows at once
std::vector<int32_t> lens = {2, 3, 1}; // row 0, 1, 2
offsets.Insert(0, lens.data(), 3);
EXPECT_EQ(offsets.GetRowCount(), 3);
EXPECT_EQ(offsets.GetTotalElementCount(), 6);
{
auto [start, end] = offsets.ElementIDRangeOfRow(0);
EXPECT_EQ(start, 0);
EXPECT_EQ(end, 2);
}
{
auto [start, end] = offsets.ElementIDRangeOfRow(1);
EXPECT_EQ(start, 2);
EXPECT_EQ(end, 5);
}
{
auto [start, end] = offsets.ElementIDRangeOfRow(2);
EXPECT_EQ(start, 5);
EXPECT_EQ(end, 6);
}
}
TEST_F(ArrayOffsetsTest, GrowingOutOfOrderInsert) {
ArrayOffsetsGrowing offsets;
// Insert out of order - row 2 arrives before row 1
std::vector<int32_t> lens0 = {2};
offsets.Insert(0, lens0.data(), 1); // row 0
std::vector<int32_t> lens2 = {1};
offsets.Insert(2, lens2.data(), 1); // row 2 (pending)
// row 1 not inserted yet, so only row 0 should be committed
EXPECT_EQ(offsets.GetRowCount(), 1);
EXPECT_EQ(offsets.GetTotalElementCount(), 2);
// Now insert row 1, which should drain pending row 2
std::vector<int32_t> lens1 = {3};
offsets.Insert(1, lens1.data(), 1); // row 1
// Now all 3 rows should be committed
EXPECT_EQ(offsets.GetRowCount(), 3);
EXPECT_EQ(offsets.GetTotalElementCount(), 6);
// Verify order is correct
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
EXPECT_EQ(row_id, 0);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(2);
EXPECT_EQ(row_id, 1);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(5);
EXPECT_EQ(row_id, 2);
}
}
TEST_F(ArrayOffsetsTest, GrowingEmptyArrays) {
ArrayOffsetsGrowing offsets;
// Insert rows with some empty arrays
std::vector<int32_t> lens = {2, 0, 3, 0}; // row 1 and row 3 are empty
offsets.Insert(0, lens.data(), 4);
EXPECT_EQ(offsets.GetRowCount(), 4);
EXPECT_EQ(offsets.GetTotalElementCount(), 5);
// Row 1 has no elements
{
auto [start, end] = offsets.ElementIDRangeOfRow(1);
EXPECT_EQ(start, 2);
EXPECT_EQ(end, 2);
}
// Row 3 has no elements
{
auto [start, end] = offsets.ElementIDRangeOfRow(3);
EXPECT_EQ(start, 5);
EXPECT_EQ(end, 5);
}
}
TEST_F(ArrayOffsetsTest, GrowingRowBitsetToElementBitset) {
ArrayOffsetsGrowing offsets;
std::vector<int32_t> lens = {2, 3, 1};
offsets.Insert(0, lens.data(), 3);
TargetBitmap row_bitset(3);
row_bitset[0] = true;
row_bitset[1] = false;
row_bitset[2] = true;
TargetBitmap valid_row_bitset(3, true);
TargetBitmapView row_view(row_bitset.data(), row_bitset.size());
TargetBitmapView valid_view(valid_row_bitset.data(),
valid_row_bitset.size());
auto [elem_bitset, valid_elem_bitset] =
offsets.RowBitsetToElementBitset(row_view, valid_view);
EXPECT_EQ(elem_bitset.size(), 6);
EXPECT_TRUE(elem_bitset[0]);
EXPECT_TRUE(elem_bitset[1]);
EXPECT_FALSE(elem_bitset[2]);
EXPECT_FALSE(elem_bitset[3]);
EXPECT_FALSE(elem_bitset[4]);
EXPECT_TRUE(elem_bitset[5]);
}
TEST_F(ArrayOffsetsTest, GrowingConcurrentRead) {
ArrayOffsetsGrowing offsets;
// Insert initial data
std::vector<int32_t> lens = {2, 3, 1};
offsets.Insert(0, lens.data(), 3);
// Concurrent reads should be safe
std::vector<std::thread> threads;
for (int t = 0; t < 4; ++t) {
threads.emplace_back([&offsets]() {
for (int i = 0; i < 1000; ++i) {
auto row_count = offsets.GetRowCount();
auto elem_count = offsets.GetTotalElementCount();
EXPECT_GE(row_count, 0);
EXPECT_GE(elem_count, 0);
if (row_count > 0) {
auto [start, end] = offsets.ElementIDRangeOfRow(0);
EXPECT_GE(start, 0);
EXPECT_GE(end, start);
}
if (elem_count > 0) {
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
EXPECT_GE(row_id, 0);
EXPECT_GE(elem_idx, 0);
}
}
});
}
for (auto& t : threads) {
t.join();
}
}
TEST_F(ArrayOffsetsTest, SingleRow) {
ArrayOffsetsGrowing offsets;
std::vector<int32_t> lens = {5};
offsets.Insert(0, lens.data(), 1);
EXPECT_EQ(offsets.GetRowCount(), 1);
EXPECT_EQ(offsets.GetTotalElementCount(), 5);
for (int i = 0; i < 5; ++i) {
auto [row_id, elem_idx] = offsets.ElementIDToRowID(i);
EXPECT_EQ(row_id, 0);
EXPECT_EQ(elem_idx, i);
}
auto [start, end] = offsets.ElementIDRangeOfRow(0);
EXPECT_EQ(start, 0);
EXPECT_EQ(end, 5);
}
TEST_F(ArrayOffsetsTest, SingleElementPerRow) {
ArrayOffsetsGrowing offsets;
std::vector<int32_t> lens = {1, 1, 1, 1, 1};
offsets.Insert(0, lens.data(), 5);
EXPECT_EQ(offsets.GetRowCount(), 5);
EXPECT_EQ(offsets.GetTotalElementCount(), 5);
for (int i = 0; i < 5; ++i) {
auto [row_id, elem_idx] = offsets.ElementIDToRowID(i);
EXPECT_EQ(row_id, i);
EXPECT_EQ(elem_idx, 0);
auto [start, end] = offsets.ElementIDRangeOfRow(i);
EXPECT_EQ(start, i);
EXPECT_EQ(end, i + 1);
}
}
TEST_F(ArrayOffsetsTest, LargeArrayLength) {
ArrayOffsetsGrowing offsets;
// Single row with many elements
std::vector<int32_t> lens = {10000};
offsets.Insert(0, lens.data(), 1);
EXPECT_EQ(offsets.GetRowCount(), 1);
EXPECT_EQ(offsets.GetTotalElementCount(), 10000);
// Test first, middle, and last elements
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(0);
EXPECT_EQ(row_id, 0);
EXPECT_EQ(elem_idx, 0);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(5000);
EXPECT_EQ(row_id, 0);
EXPECT_EQ(elem_idx, 5000);
}
{
auto [row_id, elem_idx] = offsets.ElementIDToRowID(9999);
EXPECT_EQ(row_id, 0);
EXPECT_EQ(elem_idx, 9999);
}
}

View File

@ -0,0 +1,115 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ElementFilterIterator.h"
#include "common/EasyAssert.h"
#include "exec/expression/EvalCtx.h"
#include "exec/expression/Expr.h"
namespace milvus {
ElementFilterIterator::ElementFilterIterator(
std::shared_ptr<VectorIterator> base_iterator,
exec::ExecContext* exec_context,
exec::ExprSet* expr_set)
: base_iterator_(std::move(base_iterator)),
exec_context_(exec_context),
expr_set_(expr_set) {
AssertInfo(base_iterator_ != nullptr, "Base iterator cannot be null");
AssertInfo(exec_context_ != nullptr, "Exec context cannot be null");
AssertInfo(expr_set_ != nullptr, "ExprSet cannot be null");
}
bool
ElementFilterIterator::HasNext() {
// If cache is empty and base iterator has more, fetch more
while (filtered_buffer_.empty() && base_iterator_->HasNext()) {
FetchAndFilterBatch();
}
return !filtered_buffer_.empty();
}
std::optional<std::pair<int64_t, float>>
ElementFilterIterator::Next() {
if (!HasNext()) {
return std::nullopt;
}
auto result = filtered_buffer_.front();
filtered_buffer_.pop_front();
return result;
}
void
ElementFilterIterator::FetchAndFilterBatch() {
constexpr size_t kBatchSize = 1024;
// Step 1: Fetch a batch from base iterator (up to kBatchSize elements)
element_ids_buffer_.clear();
distances_buffer_.clear();
element_ids_buffer_.reserve(kBatchSize);
distances_buffer_.reserve(kBatchSize);
while (base_iterator_->HasNext() &&
element_ids_buffer_.size() < kBatchSize) {
auto pair = base_iterator_->Next();
if (pair.has_value()) {
element_ids_buffer_.push_back(static_cast<int32_t>(pair->first));
distances_buffer_.push_back(pair->second);
}
}
// If no elements fetched, base iterator is exhausted
if (element_ids_buffer_.empty()) {
return;
}
// Step 2: Batch evaluate element-level expression
exec::EvalCtx eval_ctx(exec_context_, expr_set_, &element_ids_buffer_);
std::vector<VectorPtr> results;
// Evaluate the expression set (should contain only element_expr)
expr_set_->Eval(0, 1, true, eval_ctx, results);
AssertInfo(results.size() == 1 && results[0] != nullptr,
"ElementFilterIterator: expression evaluation should return "
"exactly one result");
// Step 3: Extract evaluation results as bitmap
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(results[0]);
AssertInfo(col_vec != nullptr,
"ElementFilterIterator: result should be ColumnVector");
AssertInfo(col_vec->IsBitmap(),
"ElementFilterIterator: result should be bitmap");
auto col_vec_size = col_vec->size();
AssertInfo(col_vec_size == element_ids_buffer_.size(),
"ElementFilterIterator: evaluation result size mismatch");
TargetBitmapView bitsetview(col_vec->GetRawData(), col_vec_size);
// Step 4: Filter elements based on evaluation results and cache them
for (size_t i = 0; i < element_ids_buffer_.size(); ++i) {
if (bitsetview[i]) {
// Element passes filter, cache it
filtered_buffer_.emplace_back(element_ids_buffer_[i],
distances_buffer_[i]);
}
}
}
} // namespace milvus

View File

@ -0,0 +1,74 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <deque>
#include <memory>
#include <optional>
#include "common/QueryResult.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
namespace milvus {
// Forward declarations
namespace exec {
class ExecContext;
class ExprSet;
} // namespace exec
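// Wraps a base VectorIterator and lazily filters its (element_id, distance)
// results with an element-level filter expression, evaluating the expression
// in batches and buffering the elements that pass.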
class ElementFilterIterator : public VectorIterator {
public:
ElementFilterIterator(std::shared_ptr<VectorIterator> base_iterator,
exec::ExecContext* exec_context,
exec::ExprSet* expr_set);
bool
HasNext() override;
std::optional<std::pair<int64_t, float>>
Next() override;
private:
// Fetch a batch from base iterator, evaluate expression, and cache results
// Steps:
// 1. Fetch up to batch_size elements from base_iterator
// 2. Batch evaluate element_expr on fetched elements
// 3. Filter elements based on evaluation results
// 4. Cache passing elements in filtered_buffer_
void
FetchAndFilterBatch();
// Base iterator to fetch elements from
std::shared_ptr<VectorIterator> base_iterator_;
// Execution context for expression evaluation
exec::ExecContext* exec_context_;
// Expression set containing element-level filter expression
exec::ExprSet* expr_set_;
// Cache of filtered elements ready to be consumed
std::deque<std::pair<int64_t, float>> filtered_buffer_;
// Reusable buffers for batch fetching (avoid repeated allocations)
FixedVector<int32_t> element_ids_buffer_;
FixedVector<float> distances_buffer_;
};
} // namespace milvus

View File

@ -18,6 +18,7 @@
#include <memory>
#include "ArrayOffsets.h"
#include "common/Tracer.h"
#include "common/Types.h"
#include "knowhere/config.h"
@ -46,6 +47,13 @@ struct SearchInfo {
std::optional<std::string> json_path_;
std::optional<milvus::DataType> json_type_;
bool strict_cast_{false};
std::shared_ptr<const IArrayOffsets> array_offsets_{
nullptr}; // For element-level search
bool
element_level() const {
return array_offsets_ != nullptr;
}
};
using SearchInfoPtr = std::shared_ptr<SearchInfo>;

View File

@ -28,6 +28,7 @@
#include <NamedType/named_type.hpp>
#include "common/FieldMeta.h"
#include "common/ArrayOffsets.h"
#include "pb/schema.pb.h"
#include "knowhere/index/index_node.h"
@ -122,16 +123,37 @@ struct OffsetDisPairComparator {
return left->GetOffDis().first < right->GetOffDis().first;
}
};
struct VectorIterator {
class VectorIterator {
public:
VectorIterator(int chunk_count,
const std::vector<int64_t>& total_rows_until_chunk = {})
virtual ~VectorIterator() = default;
virtual bool
HasNext() = 0;
virtual std::optional<std::pair<int64_t, float>>
Next() = 0;
};
// Multi-way merge iterator for vector search results from multiple chunks
//
// Merges knowhere iterators from different chunks using a min-heap,
// returning results in distance-sorted order.
class ChunkMergeIterator : public VectorIterator {
public:
ChunkMergeIterator(int chunk_count,
const std::vector<int64_t>& total_rows_until_chunk = {})
: total_rows_until_chunk_(total_rows_until_chunk) {
iterators_.reserve(chunk_count);
}
bool
HasNext() override {
return !heap_.empty();
}
std::optional<std::pair<int64_t, float>>
Next() {
Next() override {
if (!heap_.empty()) {
auto top = heap_.top();
heap_.pop();
@ -145,10 +167,7 @@ struct VectorIterator {
}
return std::nullopt;
}
bool
HasNext() {
return !heap_.empty();
}
bool
AddIterator(knowhere::IndexNode::IteratorPtr iter) {
if (!sealed && iter != nullptr) {
@ -157,6 +176,7 @@ struct VectorIterator {
}
return false;
}
void
seal() {
sealed = true;
@ -195,7 +215,7 @@ struct VectorIterator {
heap_;
bool sealed = false;
std::vector<int64_t> total_rows_until_chunk_;
//currently, VectorIterator is guaranteed to be used serially without concurrent problem, in the future
//currently, ChunkMergeIterator is guaranteed to be used serially without concurrent problem, in the future
//we may need to add mutex to protect the variable sealed
};
@ -230,15 +250,21 @@ struct SearchResult {
for (int i = 0, vec_iter_idx = 0; i < kw_iterators.size(); i++) {
vec_iter_idx = vec_iter_idx % nq;
if (vector_iterators.size() < nq) {
auto vector_iterator = std::make_shared<VectorIterator>(
auto chunk_merge_iter = std::make_shared<ChunkMergeIterator>(
chunk_count, total_rows_until_chunk);
vector_iterators.emplace_back(vector_iterator);
vector_iterators.emplace_back(chunk_merge_iter);
}
const auto& kw_iterator = kw_iterators[i];
vector_iterators[vec_iter_idx++]->AddIterator(kw_iterator);
auto chunk_merge_iter =
std::static_pointer_cast<ChunkMergeIterator>(
vector_iterators[vec_iter_idx++]);
chunk_merge_iter->AddIterator(kw_iterator);
}
for (const auto& vector_iter : vector_iterators) {
vector_iter->seal();
// Cast to ChunkMergeIterator to call seal
auto chunk_merge_iter =
std::static_pointer_cast<ChunkMergeIterator>(vector_iter);
chunk_merge_iter->seal();
}
this->vector_iterators_ = vector_iterators;
}
@ -275,6 +301,28 @@ struct SearchResult {
vector_iterators_;
// record the storage usage in search
StorageCost search_storage_cost_;
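// Element-level search state (populated when searching embeddings inside a
// struct array): element-level iterators/indices plus the array offsets
// needed to map elements back to rows.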
bool element_level_{false};
std::vector<int32_t> element_indices_;
std::optional<std::vector<std::shared_ptr<VectorIterator>>>
element_iterators_;
std::shared_ptr<const IArrayOffsets> array_offsets_{nullptr};
std::vector<std::unique_ptr<uint8_t[]>> chunk_buffers_{};
bool
HasIterators() const {
return (element_level_ && element_iterators_.has_value()) ||
(!element_level_ && vector_iterators_.has_value());
}
std::optional<std::vector<std::shared_ptr<VectorIterator>>>
GetIterators() {
if (element_level_) {
return element_iterators_;
} else {
return vector_iterators_;
}
}
};
using SearchResultPtr = std::shared_ptr<SearchResult>;

View File

@ -171,4 +171,15 @@ Schema::MmapEnabled(const FieldId& field_id) const {
return {true, it->second};
}
const FieldMeta&
Schema::GetFirstArrayFieldInStruct(const std::string& struct_name) const {
auto cache_it = struct_array_field_cache_.find(struct_name);
if (cache_it != struct_array_field_cache_.end()) {
return fields_.at(cache_it->second);
}
ThrowInfo(ErrorCode::UnexpectedError,
"No array field found in struct: {}",
struct_name);
}
} // namespace milvus

View File

@ -367,7 +367,24 @@ class Schema {
name_ids_.emplace(field_name, field_id);
id_names_.emplace(field_id, field_name);
fields_.emplace(field_id, field_meta);
// Build struct_array_field_cache_ for ARRAY/VECTOR_ARRAY fields
// Field name format: "struct_name[0].field_name"
auto data_type = field_meta.get_data_type();
if (data_type == DataType::ARRAY ||
data_type == DataType::VECTOR_ARRAY) {
const std::string& name_str = field_name.get();
auto bracket_pos = name_str.find('[');
if (bracket_pos != std::string::npos && bracket_pos > 0) {
std::string struct_name = name_str.substr(0, bracket_pos);
// Only cache the first array field for each struct
if (struct_array_field_cache_.find(struct_name) ==
struct_array_field_cache_.end()) {
struct_array_field_cache_[struct_name] = field_id;
}
}
}
fields_.emplace(field_id, std::move(field_meta));
field_ids_.emplace_back(field_id);
}
@ -392,6 +409,10 @@ class Schema {
std::pair<bool, bool>
MmapEnabled(const FieldId& field) const;
// Find the first array field belonging to a struct (cached)
const FieldMeta&
GetFirstArrayFieldInStruct(const std::string& struct_name) const;
private:
int64_t debug_id = START_USER_FIELDID;
std::vector<FieldId> field_ids_;
@ -418,6 +439,9 @@ class Schema {
bool has_mmap_setting_ = false;
bool mmap_enabled_ = false;
std::unordered_map<FieldId, bool> mmap_fields_;
// Cache for struct_name -> first array field mapping (built during AddField)
std::unordered_map<std::string, FieldId> struct_array_field_cache_;
};
using SchemaPtr = std::shared_ptr<Schema>;

View File

@ -28,6 +28,7 @@
#include "common/Consts.h"
#include "common/FieldMeta.h"
#include "common/LoadInfo.h"
#include "common/Schema.h"
#include "common/Types.h"
#include "common/EasyAssert.h"
#include "knowhere/dataset.h"

View File

@ -458,6 +458,11 @@ class VectorArrayView {
}
}
int
length() const {
return length_;
}
private:
char* data_{nullptr};
int64_t dim_ = 0;

View File

@ -32,6 +32,8 @@
#include "exec/operator/VectorSearchNode.h"
#include "exec/operator/RandomSampleNode.h"
#include "exec/operator/GroupByNode.h"
#include "exec/operator/ElementFilterNode.h"
#include "exec/operator/ElementFilterBitsNode.h"
#include "exec/Task.h"
#include "plan/PlanNode.h"
@ -104,6 +106,19 @@ DriverFactory::CreateDriver(std::unique_ptr<DriverContext> ctx,
tracer::AddEvent("create_operator: RescoresNode");
operators.push_back(
std::make_unique<PhyRescoresNode>(id, ctx.get(), rescoresnode));
} else if (auto node =
std::dynamic_pointer_cast<const plan::ElementFilterNode>(
plannode)) {
tracer::AddEvent("create_operator: ElementFilterNode");
operators.push_back(
std::make_unique<PhyElementFilterNode>(id, ctx.get(), node));
} else if (auto node = std::dynamic_pointer_cast<
const plan::ElementFilterBitsNode>(plannode)) {
tracer::AddEvent("create_operator: ElementFilterBitsNode");
operators.push_back(std::make_unique<PhyElementFilterBitsNode>(
id, ctx.get(), node));
} else {
ThrowInfo(ErrorCode::UnexpectedError, "Unknown plan node type");
}
// TODO: add more operators
}

View File

@ -28,6 +28,7 @@
#include "common/Common.h"
#include "common/Types.h"
#include "common/Exception.h"
#include "common/ArrayOffsets.h"
#include "common/OpContext.h"
#include "segcore/SegmentInterface.h"
@ -303,6 +304,64 @@ class QueryContext : public Context {
return plan_options_;
}
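// Element-level search state: set when the search targets embeddings inside
// a struct array, so downstream operators can map element IDs back to rows.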
void
set_element_level_query(bool element_level) {
element_level_query_ = element_level;
}
bool
element_level_query() const {
return element_level_query_;
}
void
set_struct_name(const std::string& field_name) {
struct_name_ = field_name;
}
const std::string&
get_struct_name() const {
return struct_name_;
}
void
set_array_offsets(std::shared_ptr<const IArrayOffsets> offsets) {
array_offsets_ = std::move(offsets);
}
std::shared_ptr<const IArrayOffsets>
get_array_offsets() const {
return array_offsets_;
}
void
set_active_element_count(int64_t count) {
active_element_count_ = count;
}
int64_t
get_active_element_count() const {
return active_element_count_;
}
void
set_element_level_bitset(TargetBitmap&& bitset) {
element_level_bitset_ = std::move(bitset);
}
std::optional<TargetBitmap>
get_element_level_bitset() {
if (element_level_bitset_.has_value()) {
return std::move(element_level_bitset_.value());
}
return std::nullopt;
}
bool
has_element_level_bitset() const {
return element_level_bitset_.has_value();
}
private:
folly::Executor* executor_;
//folly::Executor::KeepAlive<> executor_keepalive_;
@ -331,6 +390,12 @@ class QueryContext : public Context {
int32_t consistency_level_ = 0;
query::PlanOptions plan_options_;
bool element_level_query_{false};
std::string struct_name_;
std::shared_ptr<const IArrayOffsets> array_offsets_{nullptr};
int64_t active_element_count_{0}; // Total elements in active documents
std::optional<TargetBitmap> element_level_bitset_;
};
// Represent the state of one thread of query execution.

View File

@ -29,7 +29,11 @@ PhyBinaryArithOpEvalRangeExpr::Eval(EvalCtx& context, VectorPtr& result) {
auto input = context.get_offset_input();
SetHasOffsetInput((input != nullptr));
switch (expr_->column_.data_type_) {
auto data_type = expr_->column_.data_type_;
if (expr_->column_.element_level_) {
data_type = expr_->column_.element_type_;
}
switch (data_type) {
case DataType::BOOL: {
result = ExecRangeVisitorImpl<bool>(input);
break;
@ -1840,14 +1844,27 @@ PhyBinaryArithOpEvalRangeExpr::ExecRangeVisitorImplForData(
int64_t processed_size;
if (has_offset_input_) {
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
skip_index_func,
input,
res,
valid_res,
value,
right_operand);
if (expr_->column_.element_level_) {
// For element-level filtering
processed_size = ProcessElementLevelByOffsets<T>(execute_sub_batch,
skip_index_func,
input,
res,
valid_res,
value,
right_operand);
} else {
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
skip_index_func,
input,
res,
valid_res,
value,
right_operand);
}
} else {
AssertInfo(!expr_->column_.element_level_,
"Element-level filtering is not supported without offsets");
processed_size = ProcessDataChunks<T>(execute_sub_batch,
skip_index_func,
res,

View File

@ -31,7 +31,12 @@ PhyBinaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
auto input = context.get_offset_input();
SetHasOffsetInput((input != nullptr));
switch (expr_->column_.data_type_) {
auto data_type = expr_->column_.data_type_;
if (expr_->column_.element_level_) {
data_type = expr_->column_.element_type_;
}
switch (data_type) {
case DataType::BOOL: {
result = ExecRangeVisitorImpl<bool>(context);
break;
@ -414,14 +419,28 @@ PhyBinaryRangeFilterExpr::ExecRangeVisitorImplForData(EvalCtx& context) {
};
int64_t processed_size;
if (has_offset_input_) {
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
skip_index_func,
input,
res,
valid_res,
val1,
val2);
if (expr_->column_.element_level_) {
// For element-level filtering
processed_size = ProcessElementLevelByOffsets<T>(execute_sub_batch,
skip_index_func,
input,
res,
valid_res,
val1,
val2);
} else {
// For doc-level filtering
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
skip_index_func,
input,
res,
valid_res,
val1,
val2);
}
} else {
AssertInfo(!expr_->column_.element_level_,
"Element-level filtering is not supported without offsets");
processed_size = ProcessDataChunks<T>(
execute_sub_batch, skip_index_func, res, valid_res, val1, val2);
}

View File

@ -71,6 +71,9 @@ PhyColumnExpr::Eval(EvalCtx& context, VectorPtr& result) {
template <typename T>
VectorPtr
PhyColumnExpr::DoEval(OffsetVector* input) {
AssertInfo(!expr_->GetColumn().element_level_,
"ColumnExpr does not support element-level access");
// similar to PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op)
// take offsets as input
if (has_offset_input_) {

View File

@ -21,6 +21,8 @@
#include <string>
#include <type_traits>
#include "common/Array.h"
#include "common/ArrayOffsets.h"
#include "common/FieldDataInterface.h"
#include "common/Json.h"
#include "common/OpContext.h"
@ -679,6 +681,173 @@ class SegmentExpr : public Expr {
return input->size();
}
// Process element-level data by element IDs
// Handles the type mismatch between storage (ArrayView) and element type
// Currently only implemented for sealed chunked segments
template <typename ElementType, typename FUNC, typename... ValTypes>
int64_t
ProcessElementLevelByOffsets(
FUNC func,
std::function<bool(const milvus::SkipIndex&, FieldId, int)> skip_func,
OffsetVector* element_ids,
TargetBitmapView res,
TargetBitmapView valid_res,
const ValTypes&... values) {
auto& skip_index = segment_->GetSkipIndex();
if (segment_->type() == SegmentType::Sealed) {
AssertInfo(
segment_->is_chunked(),
"Element-level filtering requires chunked segment for sealed");
auto array_offsets = segment_->GetArrayOffsets(field_id_);
if (!array_offsets) {
ThrowInfo(ErrorCode::UnexpectedError,
"IArrayOffsets not found for field {}",
field_id_.get());
}
// Batch process consecutive elements belonging to the same chunk
size_t processed_size = 0;
size_t i = 0;
// Reuse these vectors to avoid repeated heap allocations
FixedVector<int32_t> offsets;
FixedVector<int32_t> elem_indices;
while (i < element_ids->size()) {
// Start of a new chunk batch
int64_t element_id = (*element_ids)[i];
auto [doc_id, elem_idx] =
array_offsets->ElementIDToRowID(element_id);
auto [chunk_id, chunk_offset] =
segment_->get_chunk_by_offset(field_id_, doc_id);
// Collect consecutive elements belonging to the same chunk
offsets.clear();
elem_indices.clear();
offsets.push_back(chunk_offset);
elem_indices.push_back(elem_idx);
size_t batch_start = i;
i++;
// Look ahead for more elements in the same chunk
while (i < element_ids->size()) {
int64_t next_element_id = (*element_ids)[i];
auto [next_doc_id, next_elem_idx] =
array_offsets->ElementIDToRowID(next_element_id);
auto [next_chunk_id, next_chunk_offset] =
segment_->get_chunk_by_offset(field_id_, next_doc_id);
if (next_chunk_id != chunk_id) {
break; // Different chunk, process current batch
}
offsets.push_back(next_chunk_offset);
elem_indices.push_back(next_elem_idx);
i++;
}
// Batch fetch all ArrayViews for this chunk
auto pw = segment_->get_views_by_offsets<ArrayView>(
op_ctx_, field_id_, chunk_id, offsets);
auto [array_vec, valid_data] = pw.get();
// Process each element in this batch
for (size_t j = 0; j < offsets.size(); j++) {
size_t result_idx = batch_start + j;
if ((!skip_func ||
!skip_func(skip_index, field_id_, chunk_id)) &&
(!namespace_skip_func_.has_value() ||
!namespace_skip_func_.value()(chunk_id))) {
// Extract element from ArrayView
auto value =
array_vec[j].template get_data<ElementType>(
elem_indices[j]);
bool is_valid = !valid_data.data() || valid_data[j];
func.template operator()<FilterType::random>(
&value,
&is_valid,
nullptr,
1,
res + result_idx,
valid_res + result_idx,
values...);
} else {
// Chunk is skipped - handle exactly like ProcessDataByOffsets
if (valid_data.size() > j && !valid_data[j]) {
res[result_idx] = valid_res[result_idx] = false;
}
}
processed_size++;
}
}
return processed_size;
} else {
auto array_offsets = segment_->GetArrayOffsets(field_id_);
if (!array_offsets) {
ThrowInfo(ErrorCode::UnexpectedError,
"IArrayOffsets not found for field {}",
field_id_.get());
}
auto& skip_index = segment_->GetSkipIndex();
size_t processed_size = 0;
for (size_t i = 0; i < element_ids->size(); i++) {
int64_t element_id = (*element_ids)[i];
auto [doc_id, elem_idx] =
array_offsets->ElementIDToRowID(element_id);
// Calculate chunk_id and chunk_offset for this doc
auto chunk_id = doc_id / size_per_chunk_;
auto chunk_offset = doc_id % size_per_chunk_;
// Get the Array chunk (Growing segment stores Array, not ArrayView)
auto pw =
segment_->chunk_data<Array>(op_ctx_, field_id_, chunk_id);
auto chunk = pw.get();
const Array* array_ptr = chunk.data() + chunk_offset;
const bool* valid_data = chunk.valid_data();
if (valid_data != nullptr) {
valid_data += chunk_offset;
}
if ((!skip_func ||
!skip_func(skip_index, field_id_, chunk_id)) &&
(!namespace_skip_func_.has_value() ||
!namespace_skip_func_.value()(chunk_id))) {
// Extract element from Array
auto value = array_ptr->get_data<ElementType>(elem_idx);
bool is_valid = !valid_data || valid_data[0];
func.template operator()<FilterType::random>(
&value,
&is_valid,
nullptr,
1,
res + processed_size,
valid_res + processed_size,
values...);
} else {
// Chunk is skipped
if (valid_data && !valid_data[0]) {
res[processed_size] = valid_res[processed_size] = false;
}
}
processed_size++;
}
return processed_size;
}
}
// Template parameter to control whether segment offsets are needed (for GIS functions)
template <typename T,
bool NeedSegmentOffsets = false,

View File

@ -35,7 +35,11 @@ PhyTermFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
result = ExecPkTermImpl();
return;
}
switch (expr_->column_.data_type_) {
auto data_type = expr_->column_.data_type_;
if (expr_->column_.element_level_) {
data_type = expr_->column_.element_type_;
}
switch (data_type) {
case DataType::BOOL: {
result = ExecVisitorImpl<bool>(context);
break;
@ -1001,13 +1005,25 @@ PhyTermFilterExpr::ExecVisitorImplForData(EvalCtx& context) {
int64_t processed_size;
if (has_offset_input_) {
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
skip_index_func,
input,
res,
valid_res,
arg_set_);
if (expr_->column_.element_level_) {
// For element-level filtering
processed_size = ProcessElementLevelByOffsets<T>(execute_sub_batch,
skip_index_func,
input,
res,
valid_res,
arg_set_);
} else {
processed_size = ProcessDataByOffsets<T>(execute_sub_batch,
skip_index_func,
input,
res,
valid_res,
arg_set_);
}
} else {
AssertInfo(!expr_->column_.element_level_,
"Element-level filtering is not supported without offsets");
processed_size = ProcessDataChunks<T>(
execute_sub_batch, skip_index_func, res, valid_res, arg_set_);
}

View File

@ -163,7 +163,11 @@ PhyUnaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
auto input = context.get_offset_input();
SetHasOffsetInput((input != nullptr));
switch (expr_->column_.data_type_) {
auto data_type = expr_->column_.data_type_;
if (expr_->column_.element_level_) {
data_type = expr_->column_.element_type_;
}
switch (data_type) {
case DataType::BOOL: {
result = ExecRangeVisitorImpl<bool>(context);
break;
@ -1713,9 +1717,17 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplForData(EvalCtx& context) {
int64_t processed_size;
if (has_offset_input_) {
processed_size = ProcessDataByOffsets<T>(
execute_sub_batch, skip_index_func, input, res, valid_res, val);
if (expr_->column_.element_level_) {
// For element-level filtering
processed_size = ProcessElementLevelByOffsets<T>(
execute_sub_batch, skip_index_func, input, res, valid_res, val);
} else {
processed_size = ProcessDataByOffsets<T>(
execute_sub_batch, skip_index_func, input, res, valid_res, val);
}
} else {
AssertInfo(!expr_->column_.element_level_,
"Element-level filtering is not supported without offsets");
processed_size = ProcessDataChunks<T>(
execute_sub_batch, skip_index_func, res, valid_res, val);
}

View File

@ -0,0 +1,215 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ElementFilterBitsNode.h"
#include "common/Tracer.h"
#include "fmt/format.h"
#include "monitor/Monitor.h"
namespace milvus {
namespace exec {
PhyElementFilterBitsNode::PhyElementFilterBitsNode(
int32_t operator_id,
DriverContext* driverctx,
const std::shared_ptr<const plan::ElementFilterBitsNode>&
element_filter_bits_node)
: Operator(driverctx,
DataType::NONE,
operator_id,
"element_filter_bits_plan_node",
"PhyElementFilterBitsNode"),
struct_name_(element_filter_bits_node->struct_name()) {
ExecContext* exec_context = operator_context_->get_exec_context();
query_context_ = exec_context->get_query_context();
// Build expression set from element-level expression
std::vector<expr::TypedExprPtr> exprs;
exprs.emplace_back(element_filter_bits_node->element_filter());
element_exprs_ = std::make_unique<ExprSet>(exprs, exec_context);
}
void
PhyElementFilterBitsNode::AddInput(RowVectorPtr& input) {
input_ = std::move(input);
}
RowVectorPtr
PhyElementFilterBitsNode::GetOutput() {
if (is_finished_ || input_ == nullptr) {
return nullptr;
}
DeferLambda([&]() { is_finished_ = true; });
tracer::AutoSpan span(
"PhyElementFilterBitsNode::GetOutput", tracer::GetRootSpan(), true);
std::chrono::high_resolution_clock::time_point start_time =
std::chrono::high_resolution_clock::now();
std::chrono::high_resolution_clock::time_point step_time;
// Step 1: Get array offsets
auto segment = query_context_->get_segment();
auto& field_meta =
segment->get_schema().GetFirstArrayFieldInStruct(struct_name_);
auto field_id = field_meta.get_id();
auto array_offsets = segment->GetArrayOffsets(field_id);
if (array_offsets == nullptr) {
ThrowInfo(ErrorCode::UnexpectedError,
"IArrayOffsets not found for field {}",
field_id.get());
}
query_context_->set_array_offsets(array_offsets);
auto [first_elem, _] =
array_offsets->ElementIDRangeOfRow(query_context_->get_active_count());
query_context_->set_active_element_count(first_elem);
// Step 2: Prepare doc bitset
auto col_input = GetColumnVector(input_);
TargetBitmapView doc_bitset(col_input->GetRawData(), col_input->size());
TargetBitmapView doc_bitset_valid(col_input->GetValidRawData(),
col_input->size());
doc_bitset.flip();
// Step 3: Convert doc bitset to element offsets
FixedVector<int32_t> element_offsets =
DocBitsetToElementOffsets(doc_bitset);
// Step 4: Evaluate element expression
auto [expr_result, valid_expr_result] =
EvaluateElementExpression(element_offsets);
// Step 5: Set query context
query_context_->set_element_level_query(true);
query_context_->set_struct_name(struct_name_);
std::chrono::high_resolution_clock::time_point end_time =
std::chrono::high_resolution_clock::now();
double total_cost =
std::chrono::duration<double, std::micro>(end_time - start_time)
.count();
milvus::monitor::internal_core_search_latency_scalar.Observe(total_cost /
1000);
auto filtered_count = expr_result.count();
tracer::AddEvent(
fmt::format("struct_name: {}, total_elements: {}, output_rows: {}, "
"filtered: {}, cost_us: {}",
struct_name_,
array_offsets->GetTotalElementCount(),
array_offsets->GetTotalElementCount() - filtered_count,
filtered_count,
total_cost));
std::vector<VectorPtr> col_res;
col_res.push_back(std::make_shared<ColumnVector>(
std::move(expr_result), std::move(valid_expr_result)));
return std::make_shared<RowVector>(col_res);
}
FixedVector<int32_t>
PhyElementFilterBitsNode::DocBitsetToElementOffsets(
const TargetBitmapView& doc_bitset) {
auto array_offsets = query_context_->get_array_offsets();
AssertInfo(array_offsets != nullptr, "Array offsets not available");
int64_t doc_count = array_offsets->GetRowCount();
AssertInfo(doc_bitset.size() == doc_count,
"Doc bitset size mismatch: {} vs {}",
doc_bitset.size(),
doc_count);
FixedVector<int32_t> element_offsets;
element_offsets.reserve(array_offsets->GetTotalElementCount());
// For each document that passes the filter, get all its element offsets
for (int64_t doc_id = 0; doc_id < doc_count; ++doc_id) {
if (doc_bitset[doc_id]) {
// Get element range for this document
auto [first_elem, last_elem] =
array_offsets->ElementIDRangeOfRow(doc_id);
// Add all element IDs for this document
for (int64_t elem_id = first_elem; elem_id < last_elem; ++elem_id) {
element_offsets.push_back(static_cast<int32_t>(elem_id));
}
}
}
return element_offsets;
}
std::pair<TargetBitmap, TargetBitmap>
PhyElementFilterBitsNode::EvaluateElementExpression(
FixedVector<int32_t>& element_offsets) {
tracer::AutoSpan span("PhyElementFilterBitsNode::EvaluateElementExpression",
tracer::GetRootSpan(),
true);
tracer::AddEvent(fmt::format("input_elements: {}", element_offsets.size()));
// Use offset interface by passing element_offsets as third parameter
EvalCtx eval_ctx(operator_context_->get_exec_context(),
element_exprs_.get(),
&element_offsets);
std::vector<VectorPtr> results;
element_exprs_->Eval(0, 1, true, eval_ctx, results);
AssertInfo(results.size() == 1 && results[0] != nullptr,
"ElementFilterBitsNode: expression evaluation should return "
"exactly one result");
TargetBitmap bitset;
TargetBitmap valid_bitset;
int64_t total_elements = query_context_->get_active_element_count();
bitset = TargetBitmap(total_elements, false);
valid_bitset = TargetBitmap(total_elements, true);
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(results[0]);
if (!col_vec) {
ThrowInfo(ExprInvalid,
"ElementFilterBitsNode result should be ColumnVector");
}
if (!col_vec->IsBitmap()) {
ThrowInfo(ExprInvalid, "ElementFilterBitsNode result should be bitmap");
}
auto col_vec_size = col_vec->size();
TargetBitmapView bitsetview(col_vec->GetRawData(), col_vec_size);
AssertInfo(col_vec_size == element_offsets.size(),
"ElementFilterBitsNode result size mismatch: {} vs {}",
col_vec_size,
element_offsets.size());
for (size_t i = 0; i < element_offsets.size(); ++i) {
if (bitsetview[i]) {
bitset[element_offsets[i]] = true;
}
}
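// Flip back to the "set bit = filtered out" convention expected downstream.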
bitset.flip();
tracer::AddEvent(fmt::format("evaluated_elements: {}, total_elements: {}",
element_offsets.size(),
total_elements));
return std::make_pair(std::move(bitset), std::move(valid_bitset));
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,93 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "exec/Driver.h"
#include "exec/expression/Expr.h"
#include "exec/operator/Operator.h"
#include "exec/QueryContext.h"
#include "common/ArrayOffsets.h"
#include "expr/ITypeExpr.h"
namespace milvus {
namespace exec {
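// Physical operator for the pre-filter path: expands the doc-level bitset
// produced by upstream filter nodes into element offsets, evaluates the
// element_filter expression on those elements, and outputs an element-level
// bitset consumed by the vector search node.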
class PhyElementFilterBitsNode : public Operator {
public:
PhyElementFilterBitsNode(
int32_t operator_id,
DriverContext* ctx,
const std::shared_ptr<const plan::ElementFilterBitsNode>&
element_filter_bits_node);
bool
IsFilter() override {
return true;
}
bool
NeedInput() const override {
return !input_;
}
void
AddInput(RowVectorPtr& input) override;
RowVectorPtr
GetOutput() override;
bool
IsFinished() override {
return is_finished_;
}
void
Close() override {
Operator::Close();
if (element_exprs_) {
element_exprs_->Clear();
}
}
BlockingReason
IsBlocked(ContinueFuture* /* unused */) override {
return BlockingReason::kNotBlocked;
}
std::string
ToString() const override {
return "PhyElementFilterBitsNode";
}
private:
FixedVector<int32_t>
DocBitsetToElementOffsets(const TargetBitmapView& doc_bitset);
std::pair<TargetBitmap, TargetBitmap>
EvaluateElementExpression(FixedVector<int32_t>& element_offsets);
std::unique_ptr<ExprSet> element_exprs_;
QueryContext* query_context_;
std::string struct_name_;
bool is_finished_{false};
};
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,129 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ElementFilterNode.h"
#include "common/Tracer.h"
#include "common/ElementFilterIterator.h"
#include "monitor/Monitor.h"
namespace milvus {
namespace exec {
PhyElementFilterNode::PhyElementFilterNode(
int32_t operator_id,
DriverContext* driverctx,
const std::shared_ptr<const plan::ElementFilterNode>& element_filter_node)
: Operator(driverctx,
element_filter_node->output_type(),
operator_id,
element_filter_node->id(),
"PhyElementFilterNode"),
struct_name_(element_filter_node->struct_name()) {
ExecContext* exec_context = operator_context_->get_exec_context();
query_context_ = exec_context->get_query_context();
std::vector<expr::TypedExprPtr> exprs;
exprs.emplace_back(element_filter_node->element_filter());
element_exprs_ = std::make_unique<ExprSet>(exprs, exec_context);
}
void
PhyElementFilterNode::AddInput(RowVectorPtr& input) {
input_ = std::move(input);
}
RowVectorPtr
PhyElementFilterNode::GetOutput() {
if (is_finished_ || !no_more_input_) {
return nullptr;
}
tracer::AutoSpan span(
"PhyElementFilterNode::GetOutput", tracer::GetRootSpan(), true);
DeferLambda([&]() { is_finished_ = true; });
if (input_ == nullptr) {
return nullptr;
}
std::chrono::high_resolution_clock::time_point start_time =
std::chrono::high_resolution_clock::now();
// Step 1: Get search result with iterators
milvus::SearchResult search_result = query_context_->get_search_result();
if (!search_result.element_level_) {
ThrowInfo(ExprInvalid,
"PhyElementFilterNode expects element-level search result");
}
if (!search_result.vector_iterators_.has_value()) {
ThrowInfo(
ExprInvalid,
"PhyElementFilterNode expects vector_iterators in search result");
}
auto segment = query_context_->get_segment();
auto& field_meta =
segment->get_schema().GetFirstArrayFieldInStruct(struct_name_);
auto field_id = field_meta.get_id();
auto array_offsets = segment->GetArrayOffsets(field_id);
if (array_offsets == nullptr) {
ThrowInfo(ErrorCode::UnexpectedError,
"IArrayOffsets not found for field {}",
field_id.get());
}
query_context_->set_array_offsets(array_offsets);
// Step 2: Wrap each iterator with ElementFilterIterator
auto& base_iterators = search_result.vector_iterators_.value();
std::vector<std::shared_ptr<VectorIterator>> wrapped_iterators;
wrapped_iterators.reserve(base_iterators.size());
ExecContext* exec_context = operator_context_->get_exec_context();
for (auto& base_iter : base_iterators) {
// ElementFilterIterator filters candidates from the base iterator using element_exprs_
auto wrapped_iter = std::make_shared<ElementFilterIterator>(
base_iter, exec_context, element_exprs_.get());
wrapped_iterators.push_back(wrapped_iter);
}
// Step 3: Update search result with wrapped iterators
search_result.vector_iterators_ = std::move(wrapped_iterators);
query_context_->set_search_result(std::move(search_result));
// Step 4: Record metrics
std::chrono::high_resolution_clock::time_point end_time =
std::chrono::high_resolution_clock::now();
double cost =
std::chrono::duration<double, std::micro>(end_time - start_time)
.count();
tracer::AddEvent(
fmt::format("PhyElementFilterNode: wrapped {} iterators, struct_name: "
"{}, cost_us: {}",
wrapped_iterators.size(),
struct_name_,
cost));
// Pass through input to downstream
return input_;
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,87 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "exec/Driver.h"
#include "exec/expression/Expr.h"
#include "exec/operator/Operator.h"
#include "exec/QueryContext.h"
#include "common/ArrayOffsets.h"
#include "expr/ITypeExpr.h"
namespace milvus {
namespace exec {
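// Physical operator for the iterative-filter path: wraps the vector iterators
// in the search result with ElementFilterIterator so the element_filter
// expression is applied while candidates are drawn from the ANN iterators.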
class PhyElementFilterNode : public Operator {
public:
PhyElementFilterNode(int32_t operator_id,
DriverContext* ctx,
const std::shared_ptr<const plan::ElementFilterNode>&
element_filter_node);
bool
IsFilter() override {
return true;
}
bool
NeedInput() const override {
return !is_finished_;
}
void
AddInput(RowVectorPtr& input) override;
RowVectorPtr
GetOutput() override;
bool
IsFinished() override {
return is_finished_;
}
void
Close() override {
Operator::Close();
if (element_exprs_) {
element_exprs_->Clear();
}
}
BlockingReason
IsBlocked(ContinueFuture* /* unused */) override {
return BlockingReason::kNotBlocked;
}
std::string
ToString() const override {
return "PhyElementFilterNode";
}
private:
std::unique_ptr<ExprSet> element_exprs_;
QueryContext* query_context_;
std::string struct_name_;
bool is_finished_{false};
};
} // namespace exec
} // namespace milvus

View File

@ -88,7 +88,8 @@ insert_helper(milvus::SearchResult& search_result,
const FixedVector<int32_t>& offsets,
const int64_t nq_index,
const int64_t unity_topk,
const int i) {
const int i,
const IArrayOffsets* array_offsets = nullptr) {
auto pos = large_is_better
? find_binsert_position<true>(search_result.distances_,
nq_index * unity_topk,
@ -98,6 +99,18 @@ insert_helper(milvus::SearchResult& search_result,
nq_index * unity_topk,
nq_index * unity_topk + topk,
distances[i]);
// For element-level: convert element_id to (doc_id, element_index)
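// e.g. if doc 0 holds 3 embeddings and doc 1 holds 2, element_id 4 maps to
// (doc_id 1, element_index 1)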
int64_t doc_id;
int32_t elem_idx = -1;
if (array_offsets != nullptr) {
auto [doc, idx] = array_offsets->ElementIDToRowID(offsets[i]);
doc_id = doc;
elem_idx = idx;
} else {
doc_id = offsets[i];
}
if (topk > pos) {
std::memmove(&search_result.distances_[pos + 1],
&search_result.distances_[pos],
@ -105,8 +118,16 @@ insert_helper(milvus::SearchResult& search_result,
std::memmove(&search_result.seg_offsets_[pos + 1],
&search_result.seg_offsets_[pos],
(topk - pos) * sizeof(int64_t));
if (array_offsets != nullptr) {
std::memmove(&search_result.element_indices_[pos + 1],
&search_result.element_indices_[pos],
(topk - pos) * sizeof(int32_t));
}
}
search_result.seg_offsets_[pos] = doc_id;
if (array_offsets != nullptr) {
search_result.element_indices_[pos] = elem_idx;
}
search_result.seg_offsets_[pos] = offsets[i];
search_result.distances_[pos] = distances[i];
++topk;
}
@ -178,10 +199,31 @@ PhyIterativeFilterNode::GetOutput() {
search_result.total_nq_,
"Vector Iterators' count must be equal to total_nq_, Check "
"your code");
bool element_level = search_result.element_level_;
auto array_offsets = query_context_->get_array_offsets();
// For element-level, we need array_offsets to convert element_id → doc_id
if (element_level) {
AssertInfo(
array_offsets != nullptr,
"Array offsets required for element-level iterative filter");
}
int nq_index = 0;
search_result.seg_offsets_.resize(nq * unity_topk, INVALID_SEG_OFFSET);
search_result.distances_.resize(nq * unity_topk);
if (element_level) {
search_result.element_indices_.resize(nq * unity_topk, -1);
}
// Reuse memory allocation across batches and nqs
FixedVector<int32_t> doc_offsets;
std::vector<int64_t> element_to_doc_mapping;
std::unordered_map<int64_t, bool> doc_eval_cache;
std::unordered_set<int64_t> unique_doc_ids;
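// Each doc is evaluated at most once per batch; its cached result is reused
// for every element that belongs to it.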
for (auto& iterator : search_result.vector_iterators_.value()) {
EvalCtx eval_ctx(operator_context_->get_exec_context(),
exprs_.get());
@ -208,8 +250,35 @@ PhyIterativeFilterNode::GetOutput() {
break;
}
}
// Clear but retain capacity
doc_offsets.clear();
element_to_doc_mapping.clear();
doc_eval_cache.clear();
unique_doc_ids.clear();
if (element_level) {
// 1. Convert element_ids to doc_ids and run the filter on those doc_ids
// 2. Keep the element_ids whose doc_ids pass the filter
element_to_doc_mapping.reserve(offsets.size());
for (auto element_id : offsets) {
auto [doc_id, elem_index] =
array_offsets->ElementIDToRowID(element_id);
element_to_doc_mapping.push_back(doc_id);
unique_doc_ids.insert(doc_id);
}
doc_offsets.reserve(unique_doc_ids.size());
for (auto doc_id : unique_doc_ids) {
doc_offsets.emplace_back(static_cast<int32_t>(doc_id));
}
} else {
doc_offsets = offsets;
}
if (is_native_supported_) {
eval_ctx.set_offset_input(&offsets);
eval_ctx.set_offset_input(&doc_offsets);
std::vector<VectorPtr> results;
exprs_->Eval(0, 1, true, eval_ctx, results);
AssertInfo(
@ -223,24 +292,52 @@ PhyIterativeFilterNode::GetOutput() {
auto col_vec_size = col_vec->size();
TargetBitmapView bitsetview(col_vec->GetRawData(),
col_vec_size);
Assert(bitsetview.size() <= batch_size);
Assert(bitsetview.size() == offsets.size());
for (auto i = 0; i < offsets.size(); ++i) {
if (bitsetview[i] > 0) {
insert_helper(search_result,
topk,
large_is_better,
distances,
offsets,
nq_index,
unity_topk,
i);
if (topk == unity_topk) {
break;
if (element_level) {
Assert(bitsetview.size() == doc_offsets.size());
for (size_t i = 0; i < doc_offsets.size(); ++i) {
doc_eval_cache[doc_offsets[i]] =
(bitsetview[i] > 0);
}
for (size_t i = 0; i < offsets.size(); ++i) {
int64_t doc_id = element_to_doc_mapping[i];
if (doc_eval_cache[doc_id]) {
insert_helper(search_result,
topk,
large_is_better,
distances,
offsets,
nq_index,
unity_topk,
i,
array_offsets.get());
if (topk == unity_topk) {
break;
}
}
}
} else {
Assert(bitsetview.size() <= batch_size);
Assert(bitsetview.size() == offsets.size());
for (auto i = 0; i < offsets.size(); ++i) {
if (bitsetview[i]) {
insert_helper(search_result,
topk,
large_is_better,
distances,
offsets,
nq_index,
unity_topk,
i);
if (topk == unity_topk) {
break;
}
}
}
}
} else {
Assert(!element_level);
for (auto i = 0; i < offsets.size(); ++i) {
if (bitset[offsets[i]] > 0) {
insert_helper(search_result,

View File

@ -17,6 +17,8 @@
#include "VectorSearchNode.h"
#include "common/Tracer.h"
#include "fmt/format.h"
#include "common/ArrayOffsets.h"
#include "exec/operator/Utils.h"
#include "monitor/Monitor.h"
namespace milvus {
@ -80,10 +82,43 @@ PhyVectorSearchNode::GetOutput() {
auto src_data = ph.get_blob();
auto src_offsets = ph.get_offsets();
auto num_queries = ph.num_of_queries_;
std::shared_ptr<const IArrayOffsets> array_offsets = nullptr;
if (ph.element_level_) {
array_offsets = segment_->GetArrayOffsets(search_info_.field_id_);
AssertInfo(array_offsets != nullptr, "Array offsets not available");
query_context_->set_array_offsets(array_offsets);
search_info_.array_offsets_ = array_offsets;
}
// There are two types of execution: pre-filter and iterative filter
// For **pre-filter**, we have execution path: FilterBitsNode -> MvccNode -> ElementFilterBitsNode -> VectorSearchNode -> ...
// For **iterative filter**, we have execution path: MvccNode -> VectorSearchNode -> ElementFilterNode -> FilterNode -> ...
//
// When searching embeddings inside an embedding list (element_level_ is true), the doc-level
// bitset must be transformed into an element-level bitset. In the pre-filter path,
// ElementFilterBitsNode has already performed this transformation; in the iterative-filter path
// we perform it here.
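// RowBitsetToElementBitset expands each row's bit to all elements of that row,
// so excluding a document excludes every embedding it contains.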
if (milvus::exec::UseVectorIterator(search_info_) && ph.element_level_) {
auto col_input = GetColumnVector(input_);
TargetBitmapView view(col_input->GetRawData(), col_input->size());
TargetBitmapView valid_view(col_input->GetValidRawData(),
col_input->size());
auto [element_bitset, valid_element_bitset] =
array_offsets->RowBitsetToElementBitset(view, valid_view);
query_context_->set_active_element_count(element_bitset.size());
std::vector<VectorPtr> col_res;
col_res.push_back(std::make_shared<ColumnVector>(
std::move(element_bitset), std::move(valid_element_bitset)));
input_ = std::make_shared<RowVector>(col_res);
}
milvus::SearchResult search_result;
auto col_input = GetColumnVector(input_);
TargetBitmapView view(col_input->GetRawData(), col_input->size());
if (view.all()) {
query_context_->set_search_result(
std::move(empty_search_result(num_queries)));
@ -94,6 +129,7 @@ PhyVectorSearchNode::GetOutput() {
milvus::BitsetView final_view((uint8_t*)col_input->GetRawData(),
col_input->size());
auto op_context = query_context_->get_op_context();
// todo(SpadeA): pass element_level to make the check more rigorous?
segment_->vector_search(search_info_,
src_data,
src_offsets,
@ -104,6 +140,7 @@ PhyVectorSearchNode::GetOutput() {
search_result);
search_result.total_data_cnt_ = final_view.size();
search_result.element_level_ = ph.element_level_;
span.GetSpan()->SetAttribute(
"result_count", static_cast<int>(search_result.seg_offsets_.size()));

View File

@ -118,6 +118,7 @@ struct ColumnInfo {
DataType element_type_;
std::vector<std::string> nested_path_;
bool nullable_;
bool element_level_;
ColumnInfo(const proto::plan::ColumnInfo& column_info)
: field_id_(column_info.field_id()),
@ -125,7 +126,8 @@ struct ColumnInfo {
element_type_(static_cast<DataType>(column_info.element_type())),
nested_path_(column_info.nested_path().begin(),
column_info.nested_path().end()),
nullable_(column_info.nullable()) {
nullable_(column_info.nullable()),
element_level_(column_info.is_element_level()) {
}
ColumnInfo(FieldId field_id,
@ -136,7 +138,8 @@ struct ColumnInfo {
data_type_(data_type),
element_type_(DataType::NONE),
nested_path_(std::move(nested_path)),
nullable_(nullable) {
nullable_(nullable),
element_level_(false) {
}
ColumnInfo(FieldId field_id,
@ -148,7 +151,8 @@ struct ColumnInfo {
data_type_(data_type),
element_type_(element_type),
nested_path_(std::move(nested_path)),
nullable_(nullable) {
nullable_(nullable),
element_level_(false) {
}
bool
@ -165,6 +169,10 @@ struct ColumnInfo {
return false;
}
if (element_level_ != other.element_level_) {
return false;
}
for (int i = 0; i < nested_path_.size(); ++i) {
if (nested_path_[i] != other.nested_path_[i]) {
return false;
@ -180,21 +188,25 @@ struct ColumnInfo {
data_type_,
element_type_,
nested_path_,
nullable_) < std::tie(other.field_id_,
other.data_type_,
other.element_type_,
other.nested_path_,
other.nullable_);
nullable_,
element_level_) < std::tie(other.field_id_,
other.data_type_,
other.element_type_,
other.nested_path_,
other.nullable_,
other.element_level_);
}
std::string
ToString() const {
return fmt::format(
"[FieldId:{}, data_type:{}, element_type:{}, nested_path:{}]",
"[FieldId:{}, data_type:{}, element_type:{}, nested_path:{}, "
"element_level:{}]",
std::to_string(field_id_.get()),
data_type_,
element_type_,
milvus::Join<std::string>(nested_path_, ","));
milvus::Join<std::string>(nested_path_, ","),
element_level_);
}
};

View File

@ -183,6 +183,126 @@ class FilterBitsNode : public PlanNode {
const expr::TypedExprPtr filter_;
};
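// Logical plan node for element_filter in the iterative-filter path; lowered
// to PhyElementFilterNode.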
class ElementFilterNode : public PlanNode {
public:
ElementFilterNode(const PlanNodeId& id,
expr::TypedExprPtr element_filter,
std::string struct_name,
std::vector<PlanNodePtr> sources)
: PlanNode(id),
sources_{std::move(sources)},
element_filter_(std::move(element_filter)),
struct_name_(std::move(struct_name)) {
AssertInfo(
element_filter_->type() == DataType::BOOL,
fmt::format(
"Element filter expression must be of type BOOLEAN, Got {}",
element_filter_->type()));
}
DataType
output_type() const override {
return DataType::NONE;
}
std::vector<PlanNodePtr>
sources() const override {
return sources_;
}
const expr::TypedExprPtr&
element_filter() const {
return element_filter_;
}
const std::string&
struct_name() const {
return struct_name_;
}
std::string_view
name() const override {
return "ElementFilter";
}
std::string
ToString() const override {
return fmt::format(
"ElementFilterNode:\n\t[struct_name:{}, element_filter:{}]",
struct_name_,
element_filter_->ToString());
}
private:
const std::vector<PlanNodePtr> sources_;
const expr::TypedExprPtr element_filter_;
const std::string struct_name_;
};
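// Logical plan node for element_filter in the pre-filter path; lowered to
// PhyElementFilterBitsNode.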
class ElementFilterBitsNode : public PlanNode {
public:
ElementFilterBitsNode(
const PlanNodeId& id,
expr::TypedExprPtr element_filter,
std::string struct_name,
std::vector<PlanNodePtr> sources = std::vector<PlanNodePtr>{})
: PlanNode(id),
sources_{std::move(sources)},
element_filter_(std::move(element_filter)),
struct_name_(std::move(struct_name)) {
AssertInfo(
element_filter_->type() == DataType::BOOL,
fmt::format(
"Element filter expression must be of type BOOLEAN, Got {}",
element_filter_->type()));
}
DataType
output_type() const override {
return DataType::BOOL;
}
std::vector<PlanNodePtr>
sources() const override {
return sources_;
}
const expr::TypedExprPtr&
element_filter() const {
return element_filter_;
}
const std::string&
struct_name() const {
return struct_name_;
}
std::string_view
name() const override {
return "ElementFilterBits";
}
std::string
ToString() const override {
return fmt::format(
"ElementFilterBitsNode:\n\t[struct_name:{}, element_filter:{}]",
struct_name_,
element_filter_->ToString());
}
expr::ExprInfo
GatherInfo() const override {
expr::ExprInfo info;
element_filter_->GatherInfo(info);
return info;
}
private:
const std::vector<PlanNodePtr> sources_;
const expr::TypedExprPtr element_filter_;
const std::string struct_name_;
};
class MvccNode : public PlanNode {
public:
MvccNode(const PlanNodeId& id,

View File

@ -55,7 +55,12 @@ ExecPlanNodeVisitor::ExecuteTask(
for (;;) {
auto result = task->Next();
if (!result) {
Assert(processed_num == query_context->get_active_count());
if (query_context->get_active_element_count() > 0) {
Assert(processed_num ==
query_context->get_active_element_count());
} else {
Assert(processed_num == query_context->get_active_count());
}
break;
}
auto childrens = result->childrens();

View File

@ -31,29 +31,37 @@ ParsePlaceholderGroup(const Plan* plan,
}
bool
check_data_type(const FieldMeta& field_meta,
const milvus::proto::common::PlaceholderType type) {
check_data_type(
const FieldMeta& field_meta,
const milvus::proto::common::PlaceholderValue& placeholder_value) {
if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
if (field_meta.get_element_type() == DataType::VECTOR_FLOAT) {
return type ==
milvus::proto::common::PlaceholderType::EmbListFloatVector;
if (placeholder_value.element_level()) {
return placeholder_value.type() ==
milvus::proto::common::PlaceholderType::FloatVector;
} else {
return placeholder_value.type() ==
milvus::proto::common::PlaceholderType::
EmbListFloatVector;
}
} else if (field_meta.get_element_type() == DataType::VECTOR_FLOAT16) {
return type ==
return placeholder_value.type() ==
milvus::proto::common::PlaceholderType::EmbListFloat16Vector;
} else if (field_meta.get_element_type() == DataType::VECTOR_BFLOAT16) {
return type == milvus::proto::common::PlaceholderType::
EmbListBFloat16Vector;
return placeholder_value.type() ==
milvus::proto::common::PlaceholderType::
EmbListBFloat16Vector;
} else if (field_meta.get_element_type() == DataType::VECTOR_BINARY) {
return type ==
return placeholder_value.type() ==
milvus::proto::common::PlaceholderType::EmbListBinaryVector;
} else if (field_meta.get_element_type() == DataType::VECTOR_INT8) {
return type ==
return placeholder_value.type() ==
milvus::proto::common::PlaceholderType::EmbListInt8Vector;
}
return false;
}
return static_cast<int>(field_meta.get_data_type()) ==
static_cast<int>(type);
static_cast<int>(placeholder_value.type());
}
std::unique_ptr<PlaceholderGroup>
@ -64,30 +72,33 @@ ParsePlaceholderGroup(const Plan* plan,
milvus::proto::common::PlaceholderGroup ph_group;
auto ok = ph_group.ParseFromArray(blob, blob_len);
Assert(ok);
for (auto& info : ph_group.placeholders()) {
for (auto& ph : ph_group.placeholders()) {
Placeholder element;
element.tag_ = info.tag();
element.tag_ = ph.tag();
element.element_level_ = ph.element_level();
Assert(plan->tag2field_.count(element.tag_));
auto field_id = plan->tag2field_.at(element.tag_);
auto& field_meta = plan->schema_->operator[](field_id);
AssertInfo(check_data_type(field_meta, info.type()),
AssertInfo(check_data_type(field_meta, ph),
"vector type must be the same, field {} - type {}, search "
"info type {}",
"ph type {}",
field_meta.get_name().get(),
field_meta.get_data_type(),
static_cast<DataType>(info.type()));
element.num_of_queries_ = info.values_size();
static_cast<DataType>(ph.type()));
element.num_of_queries_ = ph.values_size();
AssertInfo(element.num_of_queries_ > 0, "must have queries");
if (info.type() ==
if (ph.type() ==
milvus::proto::common::PlaceholderType::SparseFloatVector) {
element.sparse_matrix_ =
SparseBytesToRows(info.values(), /*validate=*/true);
SparseBytesToRows(ph.values(), /*validate=*/true);
} else {
auto line_size = info.values().Get(0).size();
auto line_size = ph.values().Get(0).size();
auto& target = element.blob_;
if (field_meta.get_data_type() != DataType::VECTOR_ARRAY) {
if (field_meta.get_sizeof() != line_size) {
if (field_meta.get_data_type() != DataType::VECTOR_ARRAY ||
ph.element_level()) {
if (field_meta.get_sizeof() != line_size &&
!ph.element_level()) {
ThrowInfo(DimNotMatch,
fmt::format(
"vector dimension mismatch, expected vector "
@ -96,7 +107,7 @@ ParsePlaceholderGroup(const Plan* plan,
line_size));
}
target.reserve(line_size * element.num_of_queries_);
for (auto& line : info.values()) {
for (auto& line : ph.values()) {
AssertInfo(line_size == line.size(),
"vector dimension mismatch, expected vector "
"size(byte) {}, actual {}.",
@ -118,7 +129,7 @@ ParsePlaceholderGroup(const Plan* plan,
auto bytes_per_vec = milvus::vector_bytes_per_element(
field_meta.get_element_type(), dim);
for (auto& line : info.values()) {
for (auto& line : ph.values()) {
target.insert(target.end(), line.begin(), line.end());
AssertInfo(
line.size() % bytes_per_vec == 0,

View File

@ -85,6 +85,7 @@ struct Placeholder {
sparse_matrix_;
// offsets for embedding list
aligned_vector<size_t> offsets_;
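// true when the placeholder holds element-level queries, i.e. embeddings
// searched against the elements of an embedding list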
bool element_level_{false};
const void*
get_blob() const {

View File

@ -63,162 +63,201 @@ MergeExprWithNamespace(const SchemaPtr schema,
return and_expr;
}
SearchInfo
ProtoParser::ParseSearchInfo(const planpb::VectorANNS& anns_proto) {
SearchInfo search_info;
auto& query_info_proto = anns_proto.query_info();
auto field_id = FieldId(anns_proto.field_id());
search_info.field_id_ = field_id;
search_info.metric_type_ = query_info_proto.metric_type();
search_info.topk_ = query_info_proto.topk();
search_info.round_decimal_ = query_info_proto.round_decimal();
search_info.search_params_ =
nlohmann::json::parse(query_info_proto.search_params());
search_info.materialized_view_involved =
query_info_proto.materialized_view_involved();
// currently, iterative filter does not support range search
if (!search_info.search_params_.contains(RADIUS)) {
if (query_info_proto.hints() != "") {
if (query_info_proto.hints() == "disable") {
search_info.iterative_filter_execution = false;
} else if (query_info_proto.hints() == ITERATIVE_FILTER) {
search_info.iterative_filter_execution = true;
} else {
// check if hints is valid
ThrowInfo(ConfigInvalid,
"hints: {} not supported",
query_info_proto.hints());
}
} else if (search_info.search_params_.contains(HINTS)) {
if (search_info.search_params_[HINTS] == ITERATIVE_FILTER) {
search_info.iterative_filter_execution = true;
} else {
// check if hints is valid
ThrowInfo(ConfigInvalid,
"hints: {} not supported",
search_info.search_params_[HINTS]);
}
}
}
if (query_info_proto.bm25_avgdl() > 0) {
search_info.search_params_[knowhere::meta::BM25_AVGDL] =
query_info_proto.bm25_avgdl();
}
if (query_info_proto.group_by_field_id() > 0) {
auto group_by_field_id = FieldId(query_info_proto.group_by_field_id());
search_info.group_by_field_id_ = group_by_field_id;
search_info.group_size_ = query_info_proto.group_size() > 0
? query_info_proto.group_size()
: 1;
search_info.strict_group_size_ = query_info_proto.strict_group_size();
// Always set json_path to distinguish between unset and empty string
// Empty string means accessing the entire JSON object
search_info.json_path_ = query_info_proto.json_path();
if (query_info_proto.json_type() !=
milvus::proto::schema::DataType::None) {
search_info.json_type_ =
static_cast<milvus::DataType>(query_info_proto.json_type());
}
search_info.strict_cast_ = query_info_proto.strict_cast();
}
if (query_info_proto.has_search_iterator_v2_info()) {
auto& iterator_v2_info_proto =
query_info_proto.search_iterator_v2_info();
search_info.iterator_v2_info_ = SearchIteratorV2Info{
.token = iterator_v2_info_proto.token(),
.batch_size = iterator_v2_info_proto.batch_size(),
};
if (iterator_v2_info_proto.has_last_bound()) {
search_info.iterator_v2_info_->last_bound =
iterator_v2_info_proto.last_bound();
}
}
return search_info;
}
std::unique_ptr<VectorPlanNode>
ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
Assert(plan_node_proto.has_vector_anns());
auto& anns_proto = plan_node_proto.vector_anns();
auto expr_parser = [&]() -> plan::PlanNodePtr {
auto expr = ParseExprs(anns_proto.predicates());
if (plan_node_proto.has_namespace_()) {
expr = MergeExprWithNamespace(
schema, expr, plan_node_proto.namespace_());
}
return std::make_shared<plan::FilterBitsNode>(
milvus::plan::GetNextPlanNodeId(), expr);
};
auto search_info_parser = [&]() -> SearchInfo {
SearchInfo search_info;
auto& query_info_proto = anns_proto.query_info();
auto field_id = FieldId(anns_proto.field_id());
search_info.field_id_ = field_id;
search_info.metric_type_ = query_info_proto.metric_type();
search_info.topk_ = query_info_proto.topk();
search_info.round_decimal_ = query_info_proto.round_decimal();
search_info.search_params_ =
nlohmann::json::parse(query_info_proto.search_params());
search_info.materialized_view_involved =
query_info_proto.materialized_view_involved();
// currently, iterative filter does not support range search
if (!search_info.search_params_.contains(RADIUS)) {
if (query_info_proto.hints() != "") {
if (query_info_proto.hints() == "disable") {
search_info.iterative_filter_execution = false;
} else if (query_info_proto.hints() == ITERATIVE_FILTER) {
search_info.iterative_filter_execution = true;
} else {
// check if hints is valid
ThrowInfo(ConfigInvalid,
"hints: {} not supported",
query_info_proto.hints());
}
} else if (search_info.search_params_.contains(HINTS)) {
if (search_info.search_params_[HINTS] == ITERATIVE_FILTER) {
search_info.iterative_filter_execution = true;
} else {
// check if hints is valid
ThrowInfo(ConfigInvalid,
"hints: {} not supported",
search_info.search_params_[HINTS]);
}
}
}
if (query_info_proto.bm25_avgdl() > 0) {
search_info.search_params_[knowhere::meta::BM25_AVGDL] =
query_info_proto.bm25_avgdl();
}
if (query_info_proto.group_by_field_id() > 0) {
auto group_by_field_id =
FieldId(query_info_proto.group_by_field_id());
search_info.group_by_field_id_ = group_by_field_id;
search_info.group_size_ = query_info_proto.group_size() > 0
? query_info_proto.group_size()
: 1;
search_info.strict_group_size_ =
query_info_proto.strict_group_size();
// Always set json_path to distinguish between unset and empty string
// Empty string means accessing the entire JSON object
search_info.json_path_ = query_info_proto.json_path();
if (query_info_proto.json_type() !=
milvus::proto::schema::DataType::None) {
search_info.json_type_ =
static_cast<milvus::DataType>(query_info_proto.json_type());
}
search_info.strict_cast_ = query_info_proto.strict_cast();
}
if (query_info_proto.has_search_iterator_v2_info()) {
auto& iterator_v2_info_proto =
query_info_proto.search_iterator_v2_info();
search_info.iterator_v2_info_ = SearchIteratorV2Info{
.token = iterator_v2_info_proto.token(),
.batch_size = iterator_v2_info_proto.batch_size(),
};
if (iterator_v2_info_proto.has_last_bound()) {
search_info.iterator_v2_info_->last_bound =
iterator_v2_info_proto.last_bound();
}
}
return search_info;
};
// Parse search information from proto
auto plan_node = std::make_unique<VectorPlanNode>();
plan_node->placeholder_tag_ = anns_proto.placeholder_tag();
plan_node->search_info_ = std::move(search_info_parser());
plan_node->search_info_ = ParseSearchInfo(anns_proto);
milvus::plan::PlanNodePtr plannode;
std::vector<milvus::plan::PlanNodePtr> sources;
// mvcc node -> vector search node -> iterative filter node
auto iterative_filter_plan = [&]() {
plannode = std::make_shared<milvus::plan::MvccNode>(
milvus::plan::GetNextPlanNodeId());
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
milvus::plan::GetNextPlanNodeId(), sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
auto expr = ParseExprs(anns_proto.predicates());
if (plan_node_proto.has_namespace_()) {
expr = MergeExprWithNamespace(
schema, expr, plan_node_proto.namespace_());
}
plannode = std::make_shared<plan::FilterNode>(
milvus::plan::GetNextPlanNodeId(), expr, sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
};
// pre filter node -> mvcc node -> vector search node
auto pre_filter_plan = [&]() {
plannode = std::move(expr_parser());
if (plan_node->search_info_.materialized_view_involved) {
const auto expr_info = plannode->GatherInfo();
knowhere::MaterializedViewSearchInfo materialized_view_search_info;
for (const auto& [expr_field_id, vals] :
expr_info.field_id_to_values) {
materialized_view_search_info
.field_id_to_touched_categories_cnt[expr_field_id] =
vals.size();
}
materialized_view_search_info.is_pure_and = expr_info.is_pure_and;
materialized_view_search_info.has_not = expr_info.has_not;
plan_node->search_info_
.search_params_[knowhere::meta::MATERIALIZED_VIEW_SEARCH_INFO] =
materialized_view_search_info;
}
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
plannode = std::make_shared<milvus::plan::MvccNode>(
milvus::plan::GetNextPlanNodeId(), sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
milvus::plan::GetNextPlanNodeId(), sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
};
// Build plan node chain based on predicate and filter execution strategy
if (anns_proto.has_predicates()) {
// currently limit iterative filter scope to search only
if (plan_node->search_info_.iterative_filter_execution &&
plan_node->search_info_.group_by_field_id_ == std::nullopt) {
iterative_filter_plan();
auto* predicate_proto = &anns_proto.predicates();
bool is_element_level = predicate_proto->expr_case() ==
proto::plan::Expr::kElementFilterExpr;
// Parse expressions based on filter type (similar to RandomSampleExpr pattern)
expr::TypedExprPtr element_expr = nullptr;
expr::TypedExprPtr doc_expr = nullptr;
std::string struct_name;
if (is_element_level) {
// Element-level query: extract both element_expr and optional doc-level predicate
auto& element_filter_expr = predicate_proto->element_filter_expr();
element_expr = ParseExprs(element_filter_expr.element_expr());
struct_name = element_filter_expr.struct_name();
if (element_filter_expr.has_predicate()) {
doc_expr = ParseExprs(element_filter_expr.predicate());
if (plan_node_proto.has_namespace_()) {
doc_expr = MergeExprWithNamespace(
schema, doc_expr, plan_node_proto.namespace_());
}
}
} else {
pre_filter_plan();
// Document-level query: only doc expr
doc_expr = ParseExprs(anns_proto.predicates());
if (plan_node_proto.has_namespace_()) {
doc_expr = MergeExprWithNamespace(
schema, doc_expr, plan_node_proto.namespace_());
}
}
bool is_iterative =
plan_node->search_info_.iterative_filter_execution &&
plan_node->search_info_.group_by_field_id_ == std::nullopt;
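// Resulting plan chains (bracketed nodes are added only when applicable):
//   iterative filter: MvccNode -> VectorSearchNode -> [ElementFilterNode] -> [FilterNode]
//   pre-filter:       [FilterBitsNode] -> MvccNode -> [ElementFilterBitsNode] -> VectorSearchNode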
if (is_iterative) {
plannode = std::make_shared<milvus::plan::MvccNode>(
milvus::plan::GetNextPlanNodeId());
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
milvus::plan::GetNextPlanNodeId(), sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
// Add element-level filter if needed
if (is_element_level) {
plannode = std::make_shared<plan::ElementFilterNode>(
milvus::plan::GetNextPlanNodeId(),
element_expr,
struct_name,
sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
}
// Add doc-level filter if present
if (doc_expr) {
plannode = std::make_shared<plan::FilterNode>(
milvus::plan::GetNextPlanNodeId(), doc_expr, sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
}
} else {
if (doc_expr) {
plannode = std::make_shared<plan::FilterBitsNode>(
milvus::plan::GetNextPlanNodeId(), doc_expr);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
if (!is_element_level &&
plan_node->search_info_.materialized_view_involved) {
const auto expr_info = plannode->GatherInfo();
knowhere::MaterializedViewSearchInfo
materialized_view_search_info;
for (const auto& [expr_field_id, vals] :
expr_info.field_id_to_values) {
materialized_view_search_info
.field_id_to_touched_categories_cnt[expr_field_id] =
vals.size();
}
materialized_view_search_info.is_pure_and =
expr_info.is_pure_and;
materialized_view_search_info.has_not = expr_info.has_not;
plan_node->search_info_.search_params_
[knowhere::meta::MATERIALIZED_VIEW_SEARCH_INFO] =
materialized_view_search_info;
}
}
plannode = std::make_shared<milvus::plan::MvccNode>(
milvus::plan::GetNextPlanNodeId(), sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
if (is_element_level) {
plannode = std::make_shared<plan::ElementFilterBitsNode>(
milvus::plan::GetNextPlanNodeId(),
element_expr,
struct_name,
sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
}
plannode = std::make_shared<milvus::plan::VectorSearchNode>(
milvus::plan::GetNextPlanNodeId(), sources);
sources = std::vector<milvus::plan::PlanNodePtr>{plannode};
}
} else {
// no filter, force set iterative filter hint to false, go with normal vector search path
@ -400,8 +439,16 @@ expr::TypedExprPtr
ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
auto& column_info = expr_pb.column_info();
auto field_id = FieldId(column_info.field_id());
auto data_type = schema->operator[](field_id).get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
auto& field = schema->operator[](field_id);
auto data_type = field.get_data_type();
if (column_info.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() ==
static_cast<DataType>(column_info.data_type()));
} else {
Assert(data_type == static_cast<DataType>(column_info.data_type()));
}
std::vector<::milvus::proto::plan::GenericValue> extra_values;
for (auto val : expr_pb.extra_values()) {
extra_values.emplace_back(val);
@ -417,8 +464,16 @@ expr::TypedExprPtr
ProtoParser::ParseNullExprs(const proto::plan::NullExpr& expr_pb) {
auto& column_info = expr_pb.column_info();
auto field_id = FieldId(column_info.field_id());
auto data_type = schema->operator[](field_id).get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
auto& field = schema->operator[](field_id);
auto data_type = field.get_data_type();
if (column_info.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() ==
static_cast<DataType>(column_info.data_type()));
} else {
Assert(data_type == static_cast<DataType>(column_info.data_type()));
}
return std::make_shared<milvus::expr::NullExpr>(
expr::ColumnInfo(column_info), expr_pb.op());
}
@ -428,8 +483,15 @@ ProtoParser::ParseBinaryRangeExprs(
const proto::plan::BinaryRangeExpr& expr_pb) {
auto& columnInfo = expr_pb.column_info();
auto field_id = FieldId(columnInfo.field_id());
auto data_type = schema->operator[](field_id).get_data_type();
Assert(data_type == (DataType)columnInfo.data_type());
auto& field = schema->operator[](field_id);
auto data_type = field.get_data_type();
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
return std::make_shared<expr::BinaryRangeFilterExpr>(
columnInfo,
expr_pb.lower_value(),
@ -443,8 +505,15 @@ ProtoParser::ParseTimestamptzArithCompareExprs(
const proto::plan::TimestamptzArithCompareExpr& expr_pb) {
auto& columnInfo = expr_pb.timestamptz_column();
auto field_id = FieldId(columnInfo.field_id());
auto data_type = schema->operator[](field_id).get_data_type();
Assert(data_type == (DataType)columnInfo.data_type());
auto& field = schema->operator[](field_id);
auto data_type = field.get_data_type();
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
return std::make_shared<expr::TimestamptzArithCompareExpr>(
columnInfo,
expr_pb.arith_op(),
@ -453,6 +522,17 @@ ProtoParser::ParseTimestamptzArithCompareExprs(
expr_pb.compare_value());
}
expr::TypedExprPtr
ProtoParser::ParseElementFilterExprs(
const proto::plan::ElementFilterExpr& expr_pb) {
// ElementFilterExpr is not a regular expression that can be evaluated directly.
// It should be handled at the PlanNode level (in PlanNodeFromProto).
// This method should never be called.
ThrowInfo(ExprInvalid,
"ParseElementFilterExprs should not be called directly. "
"ElementFilterExpr must be handled at PlanNode level.");
}
expr::TypedExprPtr
ProtoParser::ParseCallExprs(const proto::plan::CallExpr& expr_pb) {
std::vector<expr::TypedExprPtr> parameters;
@ -480,15 +560,31 @@ expr::TypedExprPtr
ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
auto& left_column_info = expr_pb.left_column_info();
auto left_field_id = FieldId(left_column_info.field_id());
auto left_data_type = schema->operator[](left_field_id).get_data_type();
Assert(left_data_type ==
static_cast<DataType>(left_column_info.data_type()));
auto& left_field = schema->operator[](left_field_id);
auto left_data_type = left_field.get_data_type();
if (left_column_info.is_element_level()) {
Assert(left_data_type == DataType::ARRAY);
Assert(left_field.get_element_type() ==
static_cast<DataType>(left_column_info.data_type()));
} else {
Assert(left_data_type ==
static_cast<DataType>(left_column_info.data_type()));
}
auto& right_column_info = expr_pb.right_column_info();
auto right_field_id = FieldId(right_column_info.field_id());
auto right_data_type = schema->operator[](right_field_id).get_data_type();
Assert(right_data_type ==
static_cast<DataType>(right_column_info.data_type()));
auto& right_field = schema->operator[](right_field_id);
auto right_data_type = right_field.get_data_type();
if (right_column_info.is_element_level()) {
Assert(right_data_type == DataType::ARRAY);
Assert(right_field.get_element_type() ==
static_cast<DataType>(right_column_info.data_type()));
} else {
Assert(right_data_type ==
static_cast<DataType>(right_column_info.data_type()));
}
return std::make_shared<expr::CompareExpr>(left_field_id,
right_field_id,
@ -501,8 +597,15 @@ expr::TypedExprPtr
ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) {
auto& columnInfo = expr_pb.column_info();
auto field_id = FieldId(columnInfo.field_id());
auto data_type = schema->operator[](field_id).get_data_type();
Assert(data_type == (DataType)columnInfo.data_type());
auto& field = schema->operator[](field_id);
auto data_type = field.get_data_type();
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
std::vector<::milvus::proto::plan::GenericValue> values;
for (size_t i = 0; i < expr_pb.values_size(); i++) {
values.emplace_back(expr_pb.values(i));
@ -532,8 +635,16 @@ ProtoParser::ParseBinaryArithOpEvalRangeExprs(
const proto::plan::BinaryArithOpEvalRangeExpr& expr_pb) {
auto& column_info = expr_pb.column_info();
auto field_id = FieldId(column_info.field_id());
auto data_type = schema->operator[](field_id).get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
auto& field = schema->operator[](field_id);
auto data_type = field.get_data_type();
if (column_info.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() ==
static_cast<DataType>(column_info.data_type()));
} else {
Assert(data_type == static_cast<DataType>(column_info.data_type()));
}
return std::make_shared<expr::BinaryArithOpEvalRangeExpr>(
column_info,
expr_pb.op(),
@ -546,8 +657,16 @@ expr::TypedExprPtr
ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) {
auto& column_info = expr_pb.info();
auto field_id = FieldId(column_info.field_id());
auto data_type = schema->operator[](field_id).get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
auto& field = schema->operator[](field_id);
auto data_type = field.get_data_type();
if (column_info.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() ==
static_cast<DataType>(column_info.data_type()));
} else {
Assert(data_type == static_cast<DataType>(column_info.data_type()));
}
return std::make_shared<expr::ExistsExpr>(column_info);
}
@ -556,8 +675,15 @@ ProtoParser::ParseJsonContainsExprs(
const proto::plan::JSONContainsExpr& expr_pb) {
auto& columnInfo = expr_pb.column_info();
auto field_id = FieldId(columnInfo.field_id());
auto data_type = schema->operator[](field_id).get_data_type();
Assert(data_type == (DataType)columnInfo.data_type());
auto& field = schema->operator[](field_id);
auto data_type = field.get_data_type();
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
std::vector<::milvus::proto::plan::GenericValue> values;
for (size_t i = 0; i < expr_pb.elements_size(); i++) {
values.emplace_back(expr_pb.elements(i));
@ -584,8 +710,15 @@ ProtoParser::ParseGISFunctionFilterExprs(
const proto::plan::GISFunctionFilterExpr& expr_pb) {
auto& columnInfo = expr_pb.column_info();
auto field_id = FieldId(columnInfo.field_id());
auto data_type = schema->operator[](field_id).get_data_type();
Assert(data_type == (DataType)columnInfo.data_type());
auto& field = schema->operator[](field_id);
auto data_type = field.get_data_type();
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
auto expr = std::make_shared<expr::GISFunctionFilterExpr>(
columnInfo, expr_pb.op(), expr_pb.wkt_string(), expr_pb.distance());
@ -671,6 +804,11 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb,
expr_pb.timestamptz_arith_compare_expr());
break;
}
case ppe::kElementFilterExpr: {
ThrowInfo(ExprInvalid,
"ElementFilterExpr should be handled at PlanNode level, "
"not in ParseExprs");
}
default: {
std::string s;
google::protobuf::TextFormat::PrintToString(expr_pb, &s);

View File

@ -106,6 +106,9 @@ class ProtoParser {
ParseTimestamptzArithCompareExprs(
const proto::plan::TimestamptzArithCompareExpr& expr_pb);
expr::TypedExprPtr
ParseElementFilterExprs(const proto::plan::ElementFilterExpr& expr_pb);
expr::TypedExprPtr
ParseValueExprs(const proto::plan::ValueExpr& expr_pb);
@ -113,6 +116,9 @@ class ProtoParser {
PlanOptionsFromProto(const proto::plan::PlanOption& plan_option_proto,
PlanOptions& plan_options);
SearchInfo
ParseSearchInfo(const proto::plan::VectorANNS& anns_proto);
private:
const SchemaPtr schema;
};

View File

@ -140,7 +140,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
// not guaranteed to return exactly `range_search_k` results; it may return more or fewer.
// setting it to -1 will return all results in the range.
search_cfg[knowhere::meta::RANGE_SEARCH_K] = topk;
sub_result.mutable_seg_offsets().resize(nq * topk);
sub_result.mutable_offsets().resize(nq * topk);
sub_result.mutable_distances().resize(nq * topk);
// For vector array (embedding list), the element type determines how the search is performed.
@ -196,8 +196,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
auto result =
ReGenRangeSearchResult(res.value(), topk, nq, query_ds.metric_type);
milvus::tracer::AddEvent("ReGenRangeSearchResult");
std::copy_n(
GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets());
std::copy_n(GetDatasetIDs(result), nq * topk, sub_result.get_offsets());
std::copy_n(
GetDatasetDistance(result), nq * topk, sub_result.get_distances());
} else {
@ -206,7 +205,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
stat = knowhere::BruteForce::SearchWithBuf<float>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset,
@ -215,7 +214,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
stat = knowhere::BruteForce::SearchWithBuf<float16>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset,
@ -224,7 +223,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
stat = knowhere::BruteForce::SearchWithBuf<bfloat16>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset,
@ -233,7 +232,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
stat = knowhere::BruteForce::SearchWithBuf<bin1>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset,
@ -242,7 +241,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
stat = knowhere::BruteForce::SearchSparseWithBuf(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset,
@ -251,7 +250,7 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
stat = knowhere::BruteForce::SearchWithBuf<int8>(
base_dataset,
query_dataset,
sub_result.mutable_seg_offsets().data(),
sub_result.mutable_offsets().data(),
sub_result.mutable_distances().data(),
search_cfg,
bitset,

View File

@ -129,7 +129,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
nullptr);
for (int i = 0; i < nq; i++) {
auto ref = SearchRef(base.get(), *(query.get() + i), nb, topk);
auto ans = result.get_seg_offsets() + i * topk;
auto ans = result.get_offsets() + i * topk;
AssertMatch(ref, ans);
}
@ -146,7 +146,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
for (int i = 0; i < nq; i++) {
auto ref = RangeSearchRef(
base.get(), *(query.get() + i), nb, 0.1, 0.5, topk);
auto ans = result2.get_seg_offsets() + i * topk;
auto ans = result2.get_offsets() + i * topk;
AssertMatch(ref, ans);
}

View File

@ -151,7 +151,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
dim,
topk,
metric_type);
auto ans = result.get_seg_offsets() + i * topk;
auto ans = result.get_offsets() + i * topk;
AssertMatch(ref, ans);
}
}

View File

@ -81,7 +81,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
SearchResult& search_result) {
auto& schema = segment.get_schema();
auto& record = segment.get_insert_record();
auto active_count =
auto active_row_count =
std::min(int64_t(bitset.size()), segment.get_active_count(timestamp));
// step 1.1: get meta
@ -162,7 +162,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
CachedSearchIterator cached_iter(search_dataset,
vec_ptr,
active_count,
active_row_count,
info,
index_info,
bitset,
@ -172,23 +172,30 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
}
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
auto max_chunk = upper_div(active_count, vec_size_per_chunk);
auto max_chunk = upper_div(active_row_count, vec_size_per_chunk);
// embedding-search-embedding on an embedding list
bool embedding_search = false;
if (data_type == DataType::VECTOR_ARRAY &&
info.array_offsets_ != nullptr) {
embedding_search = true;
}
for (int chunk_id = current_chunk_id; chunk_id < max_chunk;
++chunk_id) {
auto chunk_data = vec_ptr->get_chunk_data(chunk_id);
auto element_begin = chunk_id * vec_size_per_chunk;
auto element_end =
std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
auto size_per_chunk = element_end - element_begin;
auto row_begin = chunk_id * vec_size_per_chunk;
auto row_end =
std::min(active_row_count, (chunk_id + 1) * vec_size_per_chunk);
auto size_per_chunk = row_end - row_begin;
query::dataset::RawDataset sub_data;
std::unique_ptr<uint8_t[]> buf = nullptr;
std::vector<size_t> offsets;
if (data_type != DataType::VECTOR_ARRAY) {
sub_data = query::dataset::RawDataset{
element_begin, dim, size_per_chunk, chunk_data};
row_begin, dim, size_per_chunk, chunk_data};
} else {
// TODO(SpadeA): For VectorArray(Embedding List), data is
// discretely stored in FixedVector, which means we will copy the
@ -201,43 +208,59 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
}
buf = std::make_unique<uint8_t[]>(size);
offsets.reserve(size_per_chunk + 1);
offsets.push_back(0);
auto offset = 0;
auto ptr = buf.get();
for (int i = 0; i < size_per_chunk; ++i) {
memcpy(ptr, vec_ptr[i].data(), vec_ptr[i].byte_size());
ptr += vec_ptr[i].byte_size();
if (embedding_search) {
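// Flatten every embedding in this chunk into one contiguous buffer and count
// the embeddings, so brute-force search treats each embedding as its own row.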
auto count = 0;
auto ptr = buf.get();
for (int i = 0; i < size_per_chunk; ++i) {
memcpy(ptr, vec_ptr[i].data(), vec_ptr[i].byte_size());
ptr += vec_ptr[i].byte_size();
count += vec_ptr[i].length();
}
sub_data = query::dataset::RawDataset{
row_begin, dim, count, buf.get()};
} else {
offsets.reserve(size_per_chunk + 1);
offsets.push_back(0);
offset += vec_ptr[i].length();
offsets.push_back(offset);
auto offset = 0;
auto ptr = buf.get();
for (int i = 0; i < size_per_chunk; ++i) {
memcpy(ptr, vec_ptr[i].data(), vec_ptr[i].byte_size());
ptr += vec_ptr[i].byte_size();
offset += vec_ptr[i].length();
offsets.push_back(offset);
}
sub_data = query::dataset::RawDataset{row_begin,
dim,
size_per_chunk,
buf.get(),
offsets.data()};
}
sub_data = query::dataset::RawDataset{element_begin,
dim,
size_per_chunk,
buf.get(),
offsets.data()};
}
if (data_type == DataType::VECTOR_ARRAY) {
AssertInfo(
query_offsets != nullptr,
"query_offsets is nullptr, but data_type is vector array");
auto vector_type = data_type;
if (embedding_search) {
vector_type = element_type;
}
if (milvus::exec::UseVectorIterator(info)) {
AssertInfo(data_type != DataType::VECTOR_ARRAY,
AssertInfo(vector_type != DataType::VECTOR_ARRAY,
"vector array(embedding list) is not supported for "
"vector iterator");
if (buf != nullptr) {
search_result.chunk_buffers_.emplace_back(std::move(buf));
}
auto sub_qr =
PackBruteForceSearchIteratorsIntoSubResult(search_dataset,
sub_data,
info,
index_info,
bitset,
data_type);
vector_type);
final_qr.merge(sub_qr);
} else {
auto sub_qr = BruteForceSearch(search_dataset,
@ -245,7 +268,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
info,
index_info,
bitset,
data_type,
vector_type,
element_type,
op_context);
final_qr.merge(sub_qr);
@ -259,9 +282,18 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
search_result.AssembleChunkVectorIterators(
num_queries, max_chunk, chunk_rows, final_qr.chunk_iterators());
} else {
if (info.array_offsets_ != nullptr) {
auto [seg_offsets, elem_indices] =
final_qr.convert_to_element_offsets(
info.array_offsets_.get());
search_result.seg_offsets_ = std::move(seg_offsets);
search_result.element_indices_ = std::move(elem_indices);
search_result.element_level_ = true;
} else {
search_result.seg_offsets_ =
std::move(final_qr.mutable_offsets());
}
search_result.distances_ = std::move(final_qr.mutable_distances());
search_result.seg_offsets_ =
std::move(final_qr.mutable_seg_offsets());
}
search_result.unity_topK_ = topk;
search_result.total_nq_ = num_queries;

View File

@ -96,6 +96,28 @@ SearchOnSealedIndex(const Schema& schema,
std::round(distances[i] * multiplier) / multiplier;
}
}
// Handle element-level conversion if needed
if (search_info.array_offsets_ != nullptr) {
std::vector<int64_t> element_ids =
std::move(search_result.seg_offsets_);
search_result.seg_offsets_.resize(element_ids.size());
search_result.element_indices_.resize(element_ids.size());
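// INVALID_SEG_OFFSET marks padded slots (fewer than topk hits); keep them
// invalid and give them element index -1.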
for (size_t i = 0; i < element_ids.size(); i++) {
if (element_ids[i] == INVALID_SEG_OFFSET) {
search_result.seg_offsets_[i] = INVALID_SEG_OFFSET;
search_result.element_indices_[i] = -1;
} else {
auto [doc_id, elem_index] =
search_info.array_offsets_->ElementIDToRowID(
element_ids[i]);
search_result.seg_offsets_[i] = doc_id;
search_result.element_indices_[i] = elem_index;
}
}
search_result.element_level_ = true;
}
}
search_result.total_nq_ = num_queries;
search_result.unity_topK_ = topK;
@ -150,6 +172,17 @@ SearchOnSealedColumn(const Schema& schema,
search_info.metric_type_,
search_info.round_decimal_);
// For element-level search (embedding-search-embedding), we need to use
// element count instead of row count
bool is_element_level_search =
field.get_data_type() == DataType::VECTOR_ARRAY &&
query_offsets == nullptr;
if (is_element_level_search) {
// embedding-search-embedding on embedding list pattern
data_type = element_type;
}
auto offset = 0;
auto vector_chunks = column->GetAllChunks(op_context);
@ -157,6 +190,14 @@ SearchOnSealedColumn(const Schema& schema,
auto pw = vector_chunks[i];
auto vec_data = pw.get()->Data();
auto chunk_size = column->chunk_row_nums(i);
// For element-level search, get element count from VectorArrayOffsets
if (is_element_level_search) {
auto elem_offsets_pw = column->VectorArrayOffsets(op_context, i);
// offsets[row_count] gives total element count in this chunk
chunk_size = elem_offsets_pw.get()[chunk_size];
}
auto raw_dataset =
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
@ -201,8 +242,17 @@ SearchOnSealedColumn(const Schema& schema,
column->GetNumRowsUntilChunk(),
final_qr.chunk_iterators());
} else {
if (search_info.array_offsets_ != nullptr) {
auto [seg_offsets, elem_indices] =
final_qr.convert_to_element_offsets(
search_info.array_offsets_.get());
result.seg_offsets_ = std::move(seg_offsets);
result.element_indices_ = std::move(elem_indices);
result.element_level_ = true;
} else {
result.seg_offsets_ = std::move(final_qr.mutable_offsets());
}
result.distances_ = std::move(final_qr.mutable_distances());
result.seg_offsets_ = std::move(final_qr.mutable_seg_offsets());
}
result.unity_topK_ = query_dataset.topk;
result.total_nq_ = query_dataset.num_queries;
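
For reference, the element-count lookup used above boils down to reading the last entry of the chunk's row-offset prefix sums. A tiny illustration, with a hypothetical helper name standing in for the value returned by VectorArrayOffsets:

```
#include <cstdint>
#include <vector>

// Hypothetical helper for illustration: given the chunk's row-offset prefix
// sums, offsets[row_count] is the number of embeddings in the chunk,
// e.g. offsets = {0, 3, 5, 9} with 3 rows -> 9 elements.
int64_t ChunkElementCount(const std::vector<int64_t>& offsets, int64_t row_count) {
    return offsets[row_count];
}
```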

View File

@ -30,7 +30,7 @@ SubSearchResult::merge_impl(const SubSearchResult& right) {
for (int64_t qn = 0; qn < num_queries_; ++qn) {
auto offset = qn * topk_;
int64_t* __restrict__ left_ids = this->get_seg_offsets() + offset;
int64_t* __restrict__ left_ids = this->get_offsets() + offset;
float* __restrict__ left_distances = this->get_distances() + offset;
auto right_ids = right.get_ids() + offset;

View File

@ -18,6 +18,7 @@
#include "common/Types.h"
#include "common/Utils.h"
#include "knowhere/index/index_node.h"
#include "common/ArrayOffsets.h"
namespace milvus::query {
class SubSearchResult {
@ -31,7 +32,7 @@ class SubSearchResult {
topk_(topk),
round_decimal_(round_decimal),
metric_type_(metric_type),
seg_offsets_(num_queries * topk, INVALID_SEG_OFFSET),
offsets_(num_queries * topk, INVALID_SEG_OFFSET),
distances_(num_queries * topk, init_value(metric_type)),
chunk_iterators_(std::move(iters)) {
}
@ -52,7 +53,7 @@ class SubSearchResult {
topk_(other.topk_),
round_decimal_(other.round_decimal_),
metric_type_(std::move(other.metric_type_)),
seg_offsets_(std::move(other.seg_offsets_)),
offsets_(std::move(other.offsets_)),
distances_(std::move(other.distances_)),
chunk_iterators_(std::move(other.chunk_iterators_)) {
}
@ -77,12 +78,12 @@ class SubSearchResult {
const int64_t*
get_ids() const {
return seg_offsets_.data();
return offsets_.data();
}
int64_t*
get_seg_offsets() {
return seg_offsets_.data();
get_offsets() {
return offsets_.data();
}
const float*
@ -96,8 +97,8 @@ class SubSearchResult {
}
auto&
mutable_seg_offsets() {
return seg_offsets_;
mutable_offsets() {
return offsets_;
}
auto&
@ -116,6 +117,27 @@ class SubSearchResult {
return this->chunk_iterators_;
}
std::pair<std::vector<int64_t>, std::vector<int32_t>>
convert_to_element_offsets(const IArrayOffsets* array_offsets) {
std::vector<int64_t> doc_offsets;
std::vector<int32_t> element_indices;
doc_offsets.reserve(offsets_.size());
element_indices.reserve(offsets_.size());
for (size_t i = 0; i < offsets_.size(); i++) {
if (offsets_[i] == INVALID_SEG_OFFSET) {
doc_offsets.push_back(INVALID_SEG_OFFSET);
element_indices.push_back(-1);
} else {
auto [doc_id, elem_index] =
array_offsets->ElementIDToRowID(offsets_[i]);
doc_offsets.push_back(doc_id);
element_indices.push_back(elem_index);
}
}
return std::make_pair(std::move(doc_offsets),
std::move(element_indices));
}
private:
template <bool is_desc>
void
@ -126,7 +148,7 @@ class SubSearchResult {
int64_t topk_;
int64_t round_decimal_;
knowhere::MetricType metric_type_;
std::vector<int64_t> seg_offsets_;
std::vector<int64_t> offsets_;
std::vector<float> distances_;
std::vector<knowhere::IndexNode::IteratorPtr> chunk_iterators_;
};
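
convert_to_element_offsets leans on IArrayOffsets::ElementIDToRowID, whose implementation is not part of this hunk. Below is a minimal sketch of how such a mapping could work, assuming a prefix-sum offsets layout; this is an assumption for illustration, not the actual ArrayOffsets code.

```
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

class PrefixSumOffsets {
 public:
    // offsets[i] = first element id of row i; size = row_count + 1, offsets[0] == 0
    explicit PrefixSumOffsets(std::vector<int64_t> offsets)
        : offsets_(std::move(offsets)) {}

    // Maps a global element id to (row id, index of the element within its row).
    std::pair<int32_t, int32_t> ElementIDToRowID(int64_t elem_id) const {
        auto it = std::upper_bound(offsets_.begin(), offsets_.end(), elem_id);
        auto row = static_cast<int32_t>(std::distance(offsets_.begin(), it)) - 1;
        return {row, static_cast<int32_t>(elem_id - offsets_[row])};
    }

 private:
    std::vector<int64_t> offsets_;
};
```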

View File

@ -55,7 +55,7 @@ GenSubSearchResult(const int64_t nq,
}
}
sub_result->mutable_distances() = std::move(distances);
sub_result->mutable_seg_offsets() = std::move(ids);
sub_result->mutable_offsets() = std::move(ids);
return sub_result;
}
@ -72,7 +72,7 @@ CheckSubSearchResult(const int64_t nq,
auto ref_x = result_ref[n].top();
result_ref[n].pop();
auto index = n * topk + topk - 1 - k;
auto id = result.get_seg_offsets()[index];
auto id = result.get_offsets()[index];
auto distance = result.get_distances()[index];
ASSERT_EQ(id, ref_x);
ASSERT_EQ(distance, ref_x);

View File

@ -2676,22 +2676,62 @@ ChunkedSegmentSealedImpl::load_field_data_common(
bool generated_interim_index = generate_interim_index(field_id, num_rows);
std::unique_lock lck(mutex_);
AssertInfo(!get_bit(field_data_ready_bitset_, field_id),
"field {} data already loaded",
field_id.get());
set_bit(field_data_ready_bitset_, field_id, true);
update_row_count(num_rows);
if (generated_interim_index) {
auto column = get_column(field_id);
if (column) {
column->ManualEvictCache();
std::string struct_name;
const FieldMeta* field_meta_ptr = nullptr;
{
std::unique_lock lck(mutex_);
AssertInfo(!get_bit(field_data_ready_bitset_, field_id),
"field {} data already loaded",
field_id.get());
set_bit(field_data_ready_bitset_, field_id, true);
update_row_count(num_rows);
if (generated_interim_index) {
auto column = get_column(field_id);
if (column) {
column->ManualEvictCache();
}
}
if (data_type == DataType::GEOMETRY &&
segcore_config_.get_enable_geometry_cache()) {
// Construct GeometryCache for the entire field
LoadGeometryCache(field_id, column);
}
// Check if need to build ArrayOffsetsSealed for struct array fields
if (data_type == DataType::ARRAY ||
data_type == DataType::VECTOR_ARRAY) {
auto& field_meta = schema_->operator[](field_id);
const std::string& field_name = field_meta.get_name().get();
if (field_name.find('[') != std::string::npos &&
field_name.find(']') != std::string::npos) {
struct_name = field_name.substr(0, field_name.find('['));
auto it = struct_to_array_offsets_.find(struct_name);
if (it != struct_to_array_offsets_.end()) {
array_offsets_map_[field_id] = it->second;
} else {
field_meta_ptr = &field_meta; // need to build
}
}
}
}
if (data_type == DataType::GEOMETRY &&
segcore_config_.get_enable_geometry_cache()) {
// Construct GeometryCache for the entire field
LoadGeometryCache(field_id, column);
// Build ArrayOffsetsSealed outside lock (expensive operation)
if (field_meta_ptr) {
auto new_offsets =
ArrayOffsetsSealed::BuildFromSegment(this, *field_meta_ptr);
std::unique_lock lck(mutex_);
// Double-check after re-acquiring lock
auto it = struct_to_array_offsets_.find(struct_name);
if (it == struct_to_array_offsets_.end()) {
struct_to_array_offsets_[struct_name] = new_offsets;
array_offsets_map_[field_id] = new_offsets;
} else {
array_offsets_map_[field_id] = it->second;
}
}
}
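
The struct-name extraction above relies on the flattened field-name convention struct_name[sub_field]; every sub-field of the same STRUCT resolves to the same name and therefore shares one ArrayOffsetsSealed. A minimal sketch of that parsing, with a hypothetical helper name:

```
#include <string>

// Hypothetical helper, for illustration of the naming convention only.
// "structA[struct_float_vec]" and "structA[struct_int]" both yield "structA",
// so their field ids end up mapped to the same shared ArrayOffsetsSealed.
std::string StructNameOf(const std::string& field_name) {
    auto pos = field_name.find('[');
    return pos == std::string::npos ? std::string{} : field_name.substr(0, pos);
}
```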

View File

@ -169,6 +169,15 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
FieldId field_id,
const std::string& nested_path) const override;
std::shared_ptr<const IArrayOffsets>
GetArrayOffsets(FieldId field_id) const override {
auto it = array_offsets_map_.find(field_id);
if (it != array_offsets_map_.end()) {
return it->second;
}
return nullptr;
}
void
BulkGetJsonData(milvus::OpContext* op_ctx,
FieldId field_id,
@ -1049,6 +1058,14 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
// milvus storage internal api reader instance
std::unique_ptr<milvus_storage::api::Reader> reader_;
// ArrayOffsetsSealed for element-level filtering on array fields
// field_id -> ArrayOffsetsSealed mapping
std::unordered_map<FieldId, std::shared_ptr<ArrayOffsetsSealed>>
array_offsets_map_;
// struct_name -> ArrayOffsetsSealed mapping (temporary during load)
std::unordered_map<std::string, std::shared_ptr<ArrayOffsetsSealed>>
struct_to_array_offsets_;
};
inline SegmentSealedUPtr

View File

@ -56,6 +56,167 @@ namespace milvus::segcore {
using namespace milvus::cachinglayer;
namespace {
void
ExtractArrayLengthsFromFieldData(const std::vector<FieldDataPtr>& field_data,
const FieldMeta& field_meta,
int32_t* array_lengths) {
auto data_type = field_meta.get_data_type();
int64_t offset = 0;
for (const auto& data : field_data) {
auto num_rows = data->get_num_rows();
if (data_type == DataType::VECTOR_ARRAY) {
// Get raw pointer to VectorArray data
auto* raw_data = static_cast<const VectorArray*>(data->Data());
for (int64_t i = 0; i < num_rows; ++i) {
array_lengths[offset + i] = raw_data[i].length();
}
} else {
// For regular array types (INT32, FLOAT, etc.)
auto* raw_data = static_cast<const ArrayView*>(data->Data());
for (int64_t i = 0; i < num_rows; ++i) {
array_lengths[offset + i] = raw_data[i].length();
}
}
offset += num_rows;
}
}
void
ExtractArrayLengths(const proto::schema::FieldData& field_data,
const FieldMeta& field_meta,
int64_t num_rows,
int32_t* array_lengths) {
auto data_type = field_meta.get_data_type();
if (data_type == DataType::VECTOR_ARRAY) {
const auto& vector_array = field_data.vectors().vector_array();
int64_t dim = field_meta.get_dim();
auto element_type = field_meta.get_element_type();
for (int i = 0; i < num_rows; ++i) {
const auto& vec_field = vector_array.data(i);
int32_t array_len = 0;
switch (element_type) {
case DataType::VECTOR_FLOAT:
array_len = vec_field.float_vector().data_size() / dim;
break;
case DataType::VECTOR_FLOAT16:
array_len = vec_field.float16_vector().size() / (dim * 2);
break;
case DataType::VECTOR_BFLOAT16:
array_len = vec_field.bfloat16_vector().size() / (dim * 2);
break;
case DataType::VECTOR_BINARY:
array_len = vec_field.binary_vector().size() / (dim / 8);
break;
case DataType::VECTOR_INT8:
array_len = vec_field.int8_vector().size() / dim;
break;
default:
ThrowInfo(ErrorCode::UnexpectedError,
"Unexpected VECTOR_ARRAY element type: {}",
element_type);
}
array_lengths[i] = array_len;
}
} else {
// ARRAY: extract from scalars().array_data().data(i)
const auto& array_data = field_data.scalars().array_data();
auto element_type = field_meta.get_element_type();
for (int i = 0; i < num_rows; ++i) {
int32_t array_len = 0;
switch (element_type) {
case DataType::BOOL:
array_len = array_data.data(i).bool_data().data_size();
break;
case DataType::INT8:
case DataType::INT16:
case DataType::INT32:
array_len = array_data.data(i).int_data().data_size();
break;
case DataType::INT64:
array_len = array_data.data(i).long_data().data_size();
break;
case DataType::FLOAT:
array_len = array_data.data(i).float_data().data_size();
break;
case DataType::DOUBLE:
array_len = array_data.data(i).double_data().data_size();
break;
case DataType::STRING:
case DataType::VARCHAR:
array_len = array_data.data(i).string_data().data_size();
break;
default:
ThrowInfo(ErrorCode::UnexpectedError,
"Unexpected array type: {}",
element_type);
}
array_lengths[i] = array_len;
}
}
// Handle nullable fields
if (field_meta.is_nullable() && field_data.valid_data_size() > 0) {
const auto& valid_data = field_data.valid_data();
for (int i = 0; i < num_rows; ++i) {
if (!valid_data[i]) {
array_lengths[i] = 0; // null → empty array
}
}
}
}
} // anonymous namespace
void
SegmentGrowingImpl::InitializeArrayOffsets() {
// Group fields by struct_name
std::unordered_map<std::string, std::vector<FieldId>> struct_fields;
for (const auto& [field_id, field_meta] : schema_->get_fields()) {
const auto& field_name = field_meta.get_name().get();
// Check if field belongs to a struct: format = "struct_name[field_name]"
size_t bracket_pos = field_name.find('[');
if (bracket_pos != std::string::npos && bracket_pos > 0) {
std::string struct_name = field_name.substr(0, bracket_pos);
struct_fields[struct_name].push_back(field_id);
}
}
// Create one ArrayOffsetsGrowing per struct, shared by all its fields
for (const auto& [struct_name, field_ids] : struct_fields) {
auto array_offsets = std::make_shared<ArrayOffsetsGrowing>();
// Pick the first field as representative (any field works since array lengths are identical)
FieldId representative_field = field_ids[0];
// Map all field_ids from this struct to the same ArrayOffsetsGrowing
for (auto field_id : field_ids) {
array_offsets_map_[field_id] = array_offsets;
}
// Record representative field for Insert-time updates
struct_representative_fields_.insert(representative_field);
LOG_INFO(
"Created ArrayOffsetsGrowing for struct '{}' with {} fields, "
"representative field_id={}",
struct_name,
field_ids.size(),
representative_field.get());
}
}
int64_t
SegmentGrowingImpl::PreInsert(int64_t size) {
auto reserved_begin = insert_record_.reserved.fetch_add(size);
@ -174,6 +335,22 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
insert_record_);
}
// update ArrayOffsetsGrowing for struct fields
if (struct_representative_fields_.count(field_id) > 0) {
const auto& field_data =
insert_record_proto->fields_data(data_offset);
std::vector<int32_t> array_lengths(num_rows);
ExtractArrayLengths(
field_data, field_meta, num_rows, array_lengths.data());
auto offsets_it = array_offsets_map_.find(field_id);
if (offsets_it != array_offsets_map_.end()) {
offsets_it->second->Insert(
reserved_offset, array_lengths.data(), num_rows);
}
}
// index text.
if (field_meta.enable_match()) {
// TODO: iterate texts and call `AddText` instead of `AddTexts`. This may cost much more memory.
@ -381,6 +558,23 @@ SegmentGrowingImpl::load_field_data_common(
index->Reload();
}
// update ArrayOffsetsGrowing for struct fields
if (struct_representative_fields_.count(field_id) > 0) {
std::vector<int32_t> array_lengths(num_rows);
ExtractArrayLengthsFromFieldData(
field_data, field_meta, array_lengths.data());
auto offsets_it = array_offsets_map_.find(field_id);
if (offsets_it != array_offsets_map_.end()) {
offsets_it->second->Insert(
reserved_offset, array_lengths.data(), num_rows);
}
LOG_INFO("Updated ArrayOffsetsGrowing for field {} with {} rows",
field_id.get(),
num_rows);
}
// update the mem size
stats_.mem_size += storage::GetByteSizeOfFieldDatas(field_data);
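
The per-row lengths extracted at insert time are what ArrayOffsetsGrowing accumulates into its row-to-element mapping. A small worked example (illustrative, not Milvus code): rows with lengths {3, 2, 4} occupy element-id ranges [0, 3), [3, 5) and [5, 9).

```
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // e.g. the output of ExtractArrayLengths for three inserted rows
    std::vector<int32_t> lengths = {3, 2, 4};
    int64_t begin = 0;
    for (size_t row = 0; row < lengths.size(); ++row) {
        int64_t end = begin + lengths[row];
        std::cout << "row " << row << " -> element ids [" << begin << ", "
                  << end << ")\n";  // prints [0,3) [3,5) [5,9)
        begin = end;
    }
    return 0;
}
```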

View File

@ -34,6 +34,7 @@
#include "common/Types.h"
#include "query/PlanNode.h"
#include "common/GeometryCache.h"
#include "common/ArrayOffsets.h"
namespace milvus::segcore {
@ -330,6 +331,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
},
segment_id) {
this->CreateTextIndexes();
this->InitializeArrayOffsets();
}
~SegmentGrowingImpl() {
@ -490,6 +492,15 @@ class SegmentGrowingImpl : public SegmentGrowing {
"RemoveJsonStats not implemented for SegmentGrowingImpl");
}
std::shared_ptr<const IArrayOffsets>
GetArrayOffsets(FieldId field_id) const override {
auto it = array_offsets_map_.find(field_id);
if (it != array_offsets_map_.end()) {
return it->second;
}
return nullptr;
}
protected:
int64_t
num_chunk(FieldId field_id) const override;
@ -586,6 +597,9 @@ class SegmentGrowingImpl : public SegmentGrowing {
const std::shared_ptr<milvus_storage::api::Properties>& properties,
int64_t index);
void
InitializeArrayOffsets();
private:
storage::MmapChunkDescriptorPtr mmap_descriptor_ = nullptr;
SegcoreConfig segcore_config_;
@ -609,6 +623,15 @@ class SegmentGrowingImpl : public SegmentGrowing {
// milvus storage internal api reader instance
std::unique_ptr<milvus_storage::api::Reader> reader_;
// field_id -> ArrayOffsetsGrowing (for fast lookup via GetArrayOffsets)
// Multiple field_ids from the same struct point to the same ArrayOffsetsGrowing
std::unordered_map<FieldId, std::shared_ptr<ArrayOffsetsGrowing>>
array_offsets_map_;
// Representative field_id for each struct (used to extract array lengths during Insert)
// One field_id per struct, since all fields in the same struct have identical array lengths
std::unordered_set<FieldId> struct_representative_fields_;
};
inline SegmentGrowingPtr

View File

@ -24,6 +24,7 @@
#include <index/ScalarIndex.h>
#include "cachinglayer/CacheSlot.h"
#include "common/ArrayOffsets.h"
#include "common/EasyAssert.h"
#include "common/Json.h"
#include "common/OpContext.h"
@ -243,6 +244,11 @@ class SegmentInterface {
virtual void
Load(milvus::tracer::TraceContext& trace_ctx) = 0;
// Get IArrayOffsets for element-level filtering on array fields
// Returns nullptr if the field doesn't have IArrayOffsets
virtual std::shared_ptr<const IArrayOffsets>
GetArrayOffsets(FieldId field_id) const = 0;
};
// internal API for DSL calculation

View File

@ -67,6 +67,10 @@ GroupReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
std::vector<float> distances(size);
std::vector<int64_t> seg_offsets(size);
std::vector<GroupByValueType> group_by_values(size);
std::vector<int32_t> element_indices;
if (search_result->element_level_) {
element_indices.resize(size);
}
uint32_t index = 0;
for (int j = 0; j < total_nq_; j++) {
@ -76,6 +80,10 @@ GroupReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
seg_offsets[index] = search_result->seg_offsets_[offset];
group_by_values[index] =
search_result->group_by_values_.value()[offset];
if (search_result->element_level_) {
element_indices[index] =
search_result->element_indices_[offset];
}
index++;
real_topks[j]++;
}
@ -84,6 +92,9 @@ GroupReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
search_result->distances_.swap(distances);
search_result->seg_offsets_.swap(seg_offsets);
search_result->group_by_values_.value().swap(group_by_values);
if (search_result->element_level_) {
search_result->element_indices_.swap(element_indices);
}
AssertInfo(search_result->primary_keys_.size() ==
search_result->group_by_values_.value().size(),
"Wrong size for group_by_values size after refresh:{}, "

View File

@ -116,12 +116,17 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
real_topks[i]++;
offsets[valid_index] = offsets[index];
distances[valid_index] = distances[index];
if (search_result->element_level_)
search_result->element_indices_[valid_index] =
search_result->element_indices_[index];
valid_index++;
}
}
}
offsets.resize(valid_index);
distances.resize(valid_index);
if (search_result->element_level_)
search_result->element_indices_.resize(valid_index);
search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
std::partial_sum(real_topks.begin(),
real_topks.end(),
@ -207,6 +212,10 @@ ReduceHelper::SortEqualScoresOneNQ(size_t nq_begin,
PkType temp_pk =
std::move(search_result->primary_keys_[start + i]);
int64_t temp_offset = search_result->seg_offsets_[start + i];
int32_t temp_elem_idx =
search_result->element_level_
? search_result->element_indices_[start + i]
: -1;
size_t curr = i;
while (indices[curr] != i) {
@ -215,12 +224,20 @@ ReduceHelper::SortEqualScoresOneNQ(size_t nq_begin,
std::move(search_result->primary_keys_[start + next]);
search_result->seg_offsets_[start + curr] =
search_result->seg_offsets_[start + next];
if (search_result->element_level_) {
search_result->element_indices_[start + curr] =
search_result->element_indices_[start + next];
}
indices[curr] = curr; // Mark as processed
curr = next;
}
search_result->primary_keys_[start + curr] = std::move(temp_pk);
search_result->seg_offsets_[start + curr] = temp_offset;
if (search_result->element_level_) {
search_result->element_indices_[start + curr] =
temp_elem_idx;
}
indices[curr] = curr;
}
}
@ -258,6 +275,10 @@ ReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
search_result->distances_[offset];
search_result->seg_offsets_[index] =
search_result->seg_offsets_[offset];
if (search_result->element_level_) {
search_result->element_indices_[index] =
search_result->element_indices_[offset];
}
index++;
real_topks[j]++;
}
@ -265,6 +286,9 @@ ReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
search_result->primary_keys_.resize(index);
search_result->distances_.resize(index);
search_result->seg_offsets_.resize(index);
if (search_result->element_level_) {
search_result->element_indices_.resize(index);
}
}
void
@ -451,6 +475,15 @@ ReduceHelper::GetSearchResultDataSlice(const int slice_index,
// reserve space for distances
search_result_data->mutable_scores()->Resize(result_count, 0);
for (auto search_result : search_results_) {
if (search_result->element_level_) {
search_result_data->mutable_element_indices()
->mutable_data()
->Resize(result_count, -1);
break;
}
}
// fill pks and distances
for (auto qi = nq_begin; qi < nq_end; qi++) {
int64_t topk_count = 0;
@ -499,6 +532,13 @@ ReduceHelper::GetSearchResultDataSlice(const int slice_index,
search_result_data->mutable_scores()->Set(
loc, search_result->distances_[ki]);
if (search_result->element_level_) {
search_result_data->mutable_element_indices()
->mutable_data()
->Set(loc, search_result->element_indices_[ki]);
}
// set result offset to fill output fields data
result_pairs[loc] = {&search_result->output_fields_data_, ki};
}

View File

@ -50,6 +50,7 @@ set(MILVUS_TEST_FILES
test_rust_result.cpp
test_storage_v2_index_raw_data.cpp
test_group_by_json.cpp
test_element_filter.cpp
)
if ( NOT (INDEX_ENGINE STREQUAL "cardinal") )

View File

@ -0,0 +1,977 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include <boost/format.hpp>
#include "common/Schema.h"
#include "common/ArrayOffsets.h"
#include "query/Plan.h"
#include "test_utils/DataGen.h"
#include "test_utils/storage_test_utils.h"
#include "test_utils/cachinglayer_test_utils.h"
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
class ElementFilterSealed
: public ::testing::TestWithParam<std::tuple<bool, bool>> {
protected:
bool
use_hints() const {
return std::get<0>(GetParam());
}
bool
load_index() const {
return std::get<1>(GetParam());
}
};
TEST_P(ElementFilterSealed, RangeExpr) {
bool with_hints = use_hints();
bool with_load_index = load_index();
// Step 1: Prepare schema with array field
int dim = 4;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
DataType::VECTOR_FLOAT,
dim,
knowhere::metric::L2);
auto int_array_fid = schema->AddDebugArrayField(
"structA[price_array]", DataType::INT32, false);
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
schema->set_primary_field_id(int64_fid);
size_t N = 500;
int array_len = 3;
// Step 2: Generate test data
auto raw_data = DataGen(schema, N, 42, 0, 1, array_len);
for (int i = 0; i < raw_data.raw_->fields_data_size(); i++) {
auto* field_data = raw_data.raw_->mutable_fields_data(i);
if (field_data->field_id() == int_array_fid.get()) {
field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Clear();
for (int row = 0; row < N; row++) {
auto* array_data = field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Add();
for (int elem = 0; elem < array_len; elem++) {
int value = row * array_len + elem + 1;
array_data->mutable_int_data()->mutable_data()->Add(value);
}
}
break;
}
}
// Step 3: Create sealed segment with field data
auto segment = CreateSealedWithFieldDataLoaded(schema, raw_data);
// Step 4: Load vector index for element-level search
auto array_vec_values = raw_data.get_col<VectorFieldProto>(vec_fid);
// DataGen generates VECTOR_ARRAY with data in float_vector (flattened),
// not in vector_array (nested structure)
std::vector<float> vector_data(dim * N * array_len);
for (int i = 0; i < N; i++) {
const auto& float_vec = array_vec_values[i].float_vector().data();
// float_vec contains array_len * dim floats
for (int j = 0; j < array_len * dim; j++) {
vector_data[i * array_len * dim + j] = float_vec[j];
}
}
// For element-level search, index all elements (N * array_len vectors)
auto indexing = GenVecIndexing(N * array_len,
dim,
vector_data.data(),
knowhere::IndexEnum::INDEX_HNSW);
LoadIndexInfo load_index_info;
load_index_info.field_id = vec_fid.get();
load_index_info.index_params = GenIndexParams(indexing.get());
load_index_info.cache_index =
CreateTestCacheIndex("test", std::move(indexing));
load_index_info.index_params["metric_type"] = knowhere::metric::L2;
load_index_info.field_type = DataType::VECTOR_ARRAY;
load_index_info.element_type = DataType::VECTOR_FLOAT;
if (with_load_index) {
segment->LoadIndex(load_index_info);
}
int topK = 5;
// Step 5: Test with element-level filter
// Query: Search array elements, filter by element_value in (100, 400)
{
std::string hints_line =
with_hints ? R"(hints: "iterative_filter")" : "";
std::string raw_plan = boost::str(boost::format(R"(vector_anns: <
field_id: %1%
predicates: <
element_filter_expr: <
element_expr: <
binary_range_expr: <
column_info: <
field_id: %2%
data_type: Int32
element_type: Int32
is_element_level: true
>
lower_inclusive: false
upper_inclusive: false
lower_value: <
int64_val: 100
>
upper_value: <
int64_val: 400
>
>
>
predicate: <
binary_arith_op_eval_range_expr: <
column_info: <
field_id: %3%
data_type: Int64
>
arith_op: Mod
right_operand: <
int64_val: 2
>
op: Equal
value: <
int64_val: 0
>
>
>
struct_name: "structA"
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
%4%
search_params: "{\"ef\": 50}"
>
placeholder_tag: "$0">)") %
vec_fid.get() % int_array_fid.get() %
int64_fid.get() % hints_line);
proto::plan::PlanNode plan_node;
auto ok =
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
ASSERT_TRUE(ok) << "Failed to parse element-level filter plan";
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
ASSERT_NE(plan, nullptr);
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw =
CreatePlaceholderGroup(num_queries, dim, seed, true);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result =
segment->Search(plan.get(), ph_group.get(), 1L << 63);
// Verify results
ASSERT_NE(search_result, nullptr);
// In element-level mode, results carry element indices in addition to doc offsets
ASSERT_TRUE(search_result->element_level_);
ASSERT_FALSE(search_result->element_indices_.empty());
// Also check seg_offsets_ which stores the doc IDs
ASSERT_FALSE(search_result->seg_offsets_.empty());
ASSERT_EQ(search_result->element_indices_.size(),
search_result->seg_offsets_.size());
// Should have topK results per query
ASSERT_LE(search_result->element_indices_.size(), topK * num_queries);
std::cout << "Element-level search returned:" << std::endl;
for (auto i = 0; i < search_result->seg_offsets_.size(); i++) {
int64_t doc_id = search_result->seg_offsets_[i];
int32_t elem_idx = search_result->element_indices_[i];
float distance = search_result->distances_[i];
std::cout << "doc_id: " << doc_id << ", element_index: " << elem_idx
<< ", distance: " << distance << std::endl;
// Verify the doc_id satisfies the predicate (id % 2 == 0)
ASSERT_EQ(doc_id % 2, 0) << "Result doc_id " << doc_id
<< " should satisfy (id % 2 == 0)";
// Verify element value is in range (100, 400)
// Element value = doc_id * array_len + elem_idx + 1
int element_value = doc_id * array_len + elem_idx + 1;
ASSERT_GT(element_value, 100)
<< "Element value " << element_value << " should be > 100";
ASSERT_LT(element_value, 400)
<< "Element value " << element_value << " should be < 400";
}
// Verify distances are sorted (ascending for L2)
for (size_t i = 1; i < search_result->distances_.size(); ++i) {
ASSERT_LE(search_result->distances_[i - 1],
search_result->distances_[i])
<< "Distances should be sorted in ascending order";
}
}
}
TEST_P(ElementFilterSealed, UnaryExpr) {
bool with_hints = use_hints();
bool with_load_index = load_index();
// Step 1: Prepare schema with array field
int dim = 4;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
DataType::VECTOR_FLOAT,
dim,
knowhere::metric::L2);
auto int_array_fid = schema->AddDebugArrayField(
"structA[price_array]", DataType::INT32, false);
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
schema->set_primary_field_id(int64_fid);
size_t N = 500;
int array_len = 3;
// Step 2: Generate test data
auto raw_data = DataGen(schema, N, 42, 0, 1, array_len);
for (int i = 0; i < raw_data.raw_->fields_data_size(); i++) {
auto* field_data = raw_data.raw_->mutable_fields_data(i);
if (field_data->field_id() == int_array_fid.get()) {
field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Clear();
for (int row = 0; row < N; row++) {
auto* array_data = field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Add();
for (int elem = 0; elem < array_len; elem++) {
int value = row * array_len + elem + 1;
array_data->mutable_int_data()->mutable_data()->Add(value);
}
}
break;
}
}
// Step 3: Create sealed segment with field data
auto segment = CreateSealedWithFieldDataLoaded(schema, raw_data);
// Step 4: Load vector index for element-level search
auto array_vec_values = raw_data.get_col<VectorFieldProto>(vec_fid);
// DataGen generates VECTOR_ARRAY with data in float_vector (flattened),
// not in vector_array (nested structure)
std::vector<float> vector_data(dim * N * array_len);
for (int i = 0; i < N; i++) {
const auto& float_vec = array_vec_values[i].float_vector().data();
// float_vec contains array_len * dim floats
for (int j = 0; j < array_len * dim; j++) {
vector_data[i * array_len * dim + j] = float_vec[j];
}
}
// For element-level search, index all elements (N * array_len vectors)
auto indexing = GenVecIndexing(N * array_len,
dim,
vector_data.data(),
knowhere::IndexEnum::INDEX_HNSW);
LoadIndexInfo load_index_info;
load_index_info.field_id = vec_fid.get();
load_index_info.index_params = GenIndexParams(indexing.get());
load_index_info.cache_index =
CreateTestCacheIndex("test", std::move(indexing));
load_index_info.index_params["metric_type"] = knowhere::metric::L2;
load_index_info.field_type = DataType::VECTOR_ARRAY;
load_index_info.element_type = DataType::VECTOR_FLOAT;
if (with_load_index) {
segment->LoadIndex(load_index_info);
}
int topK = 5;
// Step 5: Test with element-level filter
// Query: Search array elements, filter by element_value > 10
{
std::string hints_line =
with_hints ? R"(hints: "iterative_filter")" : "";
std::string raw_plan = boost::str(boost::format(R"(vector_anns: <
field_id: %1%
predicates: <
element_filter_expr: <
element_expr: <
unary_range_expr: <
column_info: <
field_id: %2%
data_type: Int32
element_type: Int32
is_element_level: true
>
op: GreaterThan
value: <
int64_val: 10
>
>
>
predicate: <
binary_arith_op_eval_range_expr: <
column_info: <
field_id: %3%
data_type: Int64
>
arith_op: Mod
right_operand: <
int64_val: 2
>
op: Equal
value: <
int64_val: 0
>
>
>
struct_name: "structA"
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
%4%
search_params: "{\"ef\": 50}"
>
placeholder_tag: "$0">)") %
vec_fid.get() % int_array_fid.get() %
int64_fid.get() % hints_line);
proto::plan::PlanNode plan_node;
auto ok =
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
ASSERT_TRUE(ok) << "Failed to parse element-level filter plan";
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
ASSERT_NE(plan, nullptr);
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw =
CreatePlaceholderGroup(num_queries, dim, seed, true);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result =
segment->Search(plan.get(), ph_group.get(), 1L << 63);
// Verify results
ASSERT_NE(search_result, nullptr);
// In element-level mode, results carry element indices in addition to doc offsets
ASSERT_TRUE(search_result->element_level_);
ASSERT_FALSE(search_result->element_indices_.empty());
// Also check seg_offsets_ which stores the doc IDs
ASSERT_FALSE(search_result->seg_offsets_.empty());
ASSERT_EQ(search_result->element_indices_.size(),
search_result->seg_offsets_.size());
// Should have topK results per query
ASSERT_LE(search_result->element_indices_.size(), topK * num_queries);
std::cout << "Element-level search returned:" << std::endl;
for (auto i = 0; i < search_result->seg_offsets_.size(); i++) {
std::cout << "doc_id: " << search_result->seg_offsets_[i]
<< ", element_index: "
<< search_result->element_indices_[i] << std::endl;
std::cout << "distance: " << search_result->distances_[i]
<< std::endl;
}
// Verify distances are sorted (ascending for L2)
for (size_t i = 1; i < search_result->distances_.size(); ++i) {
ASSERT_LE(search_result->distances_[i - 1],
search_result->distances_[i])
<< "Distances should be sorted in ascending order";
}
}
}
INSTANTIATE_TEST_SUITE_P(
ElementFilter,
ElementFilterSealed,
::testing::Combine(::testing::Bool(), // with_hints: true/false
::testing::Bool() // with_load_index: true/false
),
[](const ::testing::TestParamInfo<ElementFilterSealed::ParamType>& info) {
bool with_hints = std::get<0>(info.param);
bool with_load_index = std::get<1>(info.param);
std::string name = "";
name += with_hints ? "WithHints" : "WithoutHints";
name += "_";
name += with_load_index ? "WithLoadIndex" : "WithoutLoadIndex";
return name;
});
TEST(ElementFilter, GrowingSegmentArrayOffsetsGrowing) {
int dim = 4;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
DataType::VECTOR_FLOAT,
dim,
knowhere::metric::L2);
auto int_array_fid = schema->AddDebugArrayField(
"structA[price_array]", DataType::INT32, false);
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
schema->set_primary_field_id(int64_fid);
size_t N = 500;
int array_len = 3;
auto raw_data = DataGen(schema, N, 42, 0, 1, array_len);
// Customize int_array data: doc i has elements [i*3+1, i*3+2, i*3+3]
for (int i = 0; i < raw_data.raw_->fields_data_size(); i++) {
auto* field_data = raw_data.raw_->mutable_fields_data(i);
if (field_data->field_id() == int_array_fid.get()) {
field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Clear();
for (int row = 0; row < N; row++) {
auto* array_data = field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Add();
for (int elem = 0; elem < array_len; elem++) {
int value = row * array_len + elem + 1;
array_data->mutable_int_data()->mutable_data()->Add(value);
}
}
break;
}
}
auto segment = CreateGrowingSegment(schema, empty_index_meta);
segment->PreInsert(N);
segment->Insert(0,
N,
raw_data.row_ids_.data(),
raw_data.timestamps_.data(),
raw_data.raw_);
auto growing_impl = dynamic_cast<SegmentGrowingImpl*>(segment.get());
ASSERT_NE(growing_impl, nullptr);
// Both fields should share the same ArrayOffsetsGrowing
auto offsets_vec = growing_impl->GetArrayOffsets(vec_fid);
auto offsets_int = growing_impl->GetArrayOffsets(int_array_fid);
ASSERT_NE(offsets_vec, nullptr);
ASSERT_NE(offsets_int, nullptr);
// Should point to the same object (shared)
ASSERT_EQ(offsets_vec, offsets_int)
<< "Fields in same struct should share ArrayOffsetsGrowing";
// Verify counts
ASSERT_EQ(offsets_vec->GetRowCount(), N)
<< "Should have " << N << " documents";
ASSERT_EQ(offsets_vec->GetTotalElementCount(), N * array_len)
<< "Should have " << N * array_len << " total elements";
for (int64_t doc_id = 0; doc_id < N; ++doc_id) {
for (int32_t elem_idx = 0; elem_idx < array_len; ++elem_idx) {
int64_t elem_id = doc_id * array_len + elem_idx;
auto [mapped_doc, mapped_idx] =
offsets_vec->ElementIDToRowID(elem_id);
ASSERT_EQ(mapped_doc, doc_id)
<< "Element " << elem_id << " should map to doc " << doc_id;
ASSERT_EQ(mapped_idx, elem_idx)
<< "Element " << elem_id << " should have index " << elem_idx;
}
}
}
TEST(ElementFilter, GrowingSegmentOutOfOrderInsert) {
// Test out-of-order Insert handling in ArrayOffsetsGrowing
int dim = 4;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
DataType::VECTOR_FLOAT,
dim,
knowhere::metric::L2);
auto int_array_fid = schema->AddDebugArrayField(
"structA[price_array]", DataType::INT32, false);
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
schema->set_primary_field_id(int64_fid);
int array_len = 3;
// Create growing segment
auto segment = CreateGrowingSegment(schema, empty_index_meta);
// Simulate out-of-order inserts
// Insert docs [10-19], then [0-9], then [25-34]
auto gen_batch = [&](int64_t start, int64_t count) {
auto batch = DataGen(schema, count, 42 + start, start, 1, array_len);
// Customize int_array data
for (int i = 0; i < batch.raw_->fields_data_size(); i++) {
auto* field_data = batch.raw_->mutable_fields_data(i);
if (field_data->field_id() == int_array_fid.get()) {
field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Clear();
for (int row = 0; row < count; row++) {
int64_t global_row = start + row;
auto* array_data = field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Add();
for (int elem = 0; elem < array_len; elem++) {
int value = global_row * array_len + elem + 1;
array_data->mutable_int_data()->mutable_data()->Add(
value);
}
}
break;
}
}
return batch;
};
// Insert batch 2 first (docs 10-19) - should be cached
auto batch2 = gen_batch(10, 10);
segment->PreInsert(10);
segment->Insert(
10, 10, batch2.row_ids_.data(), batch2.timestamps_.data(), batch2.raw_);
// Insert batch 1 (docs 0-9) - should trigger drain of batch 2
auto batch1 = gen_batch(0, 10);
segment->PreInsert(10);
segment->Insert(
0, 10, batch1.row_ids_.data(), batch1.timestamps_.data(), batch1.raw_);
// Insert batch 3 (docs 25-34) - should be cached (gap at 20-24)
auto batch3 = gen_batch(25, 10);
segment->PreInsert(10);
segment->Insert(
25, 10, batch3.row_ids_.data(), batch3.timestamps_.data(), batch3.raw_);
// Verify ArrayOffsetsGrowing
auto growing_impl = dynamic_cast<SegmentGrowingImpl*>(segment.get());
ASSERT_NE(growing_impl, nullptr);
auto offsets = growing_impl->GetArrayOffsets(vec_fid);
ASSERT_NE(offsets, nullptr);
// After inserting docs [0-19] (batch3 cached due to gap), committed count should be 20
ASSERT_EQ(offsets->GetRowCount(), 20)
<< "Should have committed docs 0-19, batch3 cached";
ASSERT_EQ(offsets->GetTotalElementCount(), 20 * array_len)
<< "Should have 20 docs worth of elements";
// Verify mapping for committed docs
for (int64_t doc_id = 0; doc_id < 20; ++doc_id) {
for (int32_t elem_idx = 0; elem_idx < array_len; ++elem_idx) {
int64_t elem_id = doc_id * array_len + elem_idx;
auto [mapped_doc, mapped_idx] = offsets->ElementIDToRowID(elem_id);
ASSERT_EQ(mapped_doc, doc_id)
<< "Element " << elem_id << " should map to doc " << doc_id;
ASSERT_EQ(mapped_idx, elem_idx);
}
}
}
// Parameterized test fixture for GrowingIterativeRangeExpr
class ElementFilterGrowing : public ::testing::TestWithParam<bool> {
protected:
bool
use_hints() const {
return GetParam();
}
};
TEST_P(ElementFilterGrowing, RangeExpr) {
bool with_hints = use_hints();
int dim = 4;
auto schema = std::make_shared<Schema>();
auto vec_fid = schema->AddDebugVectorArrayField("structA[array_float_vec]",
DataType::VECTOR_FLOAT,
dim,
knowhere::metric::L2);
auto int_array_fid = schema->AddDebugArrayField(
"structA[price_array]", DataType::INT32, false);
auto int64_fid = schema->AddDebugField("id", DataType::INT64);
schema->set_primary_field_id(int64_fid);
size_t N = 500;
int array_len = 3;
// Generate test data
auto raw_data = DataGen(schema, N, 42, 0, 1, array_len);
// Customize int_array data: doc i has elements [i*3+1, i*3+2, i*3+3]
for (int i = 0; i < raw_data.raw_->fields_data_size(); i++) {
auto* field_data = raw_data.raw_->mutable_fields_data(i);
if (field_data->field_id() == int_array_fid.get()) {
field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Clear();
for (int row = 0; row < N; row++) {
auto* array_data = field_data->mutable_scalars()
->mutable_array_data()
->mutable_data()
->Add();
for (int elem = 0; elem < array_len; elem++) {
int value = row * array_len + elem + 1;
array_data->mutable_int_data()->mutable_data()->Add(value);
}
}
break;
}
}
// Create growing segment and insert data
auto segment = CreateGrowingSegment(schema, empty_index_meta);
segment->PreInsert(N);
segment->Insert(0,
N,
raw_data.row_ids_.data(),
raw_data.timestamps_.data(),
raw_data.raw_);
// Verify ArrayOffsetsGrowing was built
auto growing_impl = dynamic_cast<SegmentGrowingImpl*>(segment.get());
ASSERT_NE(growing_impl, nullptr);
auto offsets = growing_impl->GetArrayOffsets(vec_fid);
ASSERT_NE(offsets, nullptr);
ASSERT_EQ(offsets->GetRowCount(), N);
ASSERT_EQ(offsets->GetTotalElementCount(), N * array_len);
int topK = 5;
// Execute element-level search with iterative filter
// Query: Search array elements where (id % 2 == 0) AND (price_array element in range (100, 400))
{
std::string hints_line =
with_hints ? R"(hints: "iterative_filter")" : "";
std::string raw_plan = boost::str(boost::format(R"(vector_anns: <
field_id: %1%
predicates: <
element_filter_expr: <
element_expr: <
binary_range_expr: <
column_info: <
field_id: %2%
data_type: Int32
element_type: Int32
is_element_level: true
>
lower_inclusive: false
upper_inclusive: false
lower_value: <
int64_val: 100
>
upper_value: <
int64_val: 400
>
>
>
predicate: <
binary_arith_op_eval_range_expr: <
column_info: <
field_id: %3%
data_type: Int64
>
arith_op: Mod
right_operand: <
int64_val: 2
>
op: Equal
value: <
int64_val: 0
>
>
>
struct_name: "structA"
>
>
query_info: <
topk: 5
round_decimal: 3
metric_type: "L2"
%4%
search_params: "{\"ef\": 50}"
>
placeholder_tag: "$0">)") %
vec_fid.get() % int_array_fid.get() %
int64_fid.get() % hints_line);
proto::plan::PlanNode plan_node;
auto ok =
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
ASSERT_TRUE(ok) << "Failed to parse element-level filter plan";
auto plan = CreateSearchPlanFromPlanNode(schema, plan_node);
ASSERT_NE(plan, nullptr);
auto num_queries = 1;
auto seed = 1024;
auto ph_group_raw =
CreatePlaceholderGroup(num_queries, dim, seed, true);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto search_result =
segment->Search(plan.get(), ph_group.get(), 1L << 63);
// Verify results
ASSERT_NE(search_result, nullptr);
// In element-level mode, results should contain element indices
ASSERT_TRUE(search_result->element_level_)
<< "Search should be in element-level mode";
ASSERT_FALSE(search_result->element_indices_.empty())
<< "Should have element indices";
ASSERT_FALSE(search_result->seg_offsets_.empty())
<< "Should have doc offsets";
ASSERT_EQ(search_result->element_indices_.size(),
search_result->seg_offsets_.size())
<< "Element indices and doc offsets should match in size";
// Should have topK results per query
ASSERT_LE(search_result->element_indices_.size(), topK * num_queries)
<< "Should not exceed topK results";
std::cout << "Growing segment element-level search results:"
<< std::endl;
for (size_t i = 0; i < search_result->seg_offsets_.size(); i++) {
int64_t doc_id = search_result->seg_offsets_[i];
int32_t elem_idx = search_result->element_indices_[i];
float distance = search_result->distances_[i];
std::cout << " [" << i << "] doc_id=" << doc_id
<< ", element_index=" << elem_idx
<< ", distance=" << distance << std::endl;
// Verify the doc_id satisfies the predicate (id % 2 == 0)
ASSERT_EQ(doc_id % 2, 0) << "Result doc_id " << doc_id
<< " should satisfy (id % 2 == 0)";
// Verify element_idx is valid
ASSERT_GE(elem_idx, 0) << "Element index should be >= 0";
ASSERT_LT(elem_idx, array_len)
<< "Element index should be < array_len";
// Verify element value is in range (100, 400)
// Element value = doc_id * array_len + elem_idx + 1
int element_value = doc_id * array_len + elem_idx + 1;
ASSERT_GT(element_value, 100)
<< "Element value " << element_value << " should be > 100";
ASSERT_LT(element_value, 400)
<< "Element value " << element_value << " should be < 400";
}
// Verify distances are sorted (ascending for L2)
for (size_t i = 1; i < search_result->distances_.size(); ++i) {
ASSERT_LE(search_result->distances_[i - 1],
search_result->distances_[i])
<< "Distances should be sorted in ascending order";
}
}
}
INSTANTIATE_TEST_SUITE_P(
ElementFilter,
ElementFilterGrowing,
::testing::Bool(), // with_hints: true/false
[](const ::testing::TestParamInfo<ElementFilterGrowing::ParamType>& info) {
bool with_hints = info.param;
return with_hints ? "WithHints" : "WithoutHints";
});
// Unit tests for ArrayOffsetsGrowing
TEST(ArrayOffsetsGrowing, PurePendingThenDrain) {
// Test: first insert goes entirely to pending, second insert triggers drain
ArrayOffsetsGrowing offsets;
// First insert: rows 2-4, all go to pending (committed_row_count_ = 0)
std::vector<int32_t> lens1 = {
3, 2, 4}; // row 2: 3 elems, row 3: 2 elems, row 4: 4 elems
offsets.Insert(2, lens1.data(), 3);
ASSERT_EQ(offsets.GetRowCount(), 0) << "No rows should be committed yet";
ASSERT_EQ(offsets.GetTotalElementCount(), 0)
<< "No elements should exist yet";
// Second insert: rows 0-1, triggers drain of pending rows 2-4
std::vector<int32_t> lens2 = {2, 3}; // row 0: 2 elems, row 1: 3 elems
offsets.Insert(0, lens2.data(), 2);
ASSERT_EQ(offsets.GetRowCount(), 5) << "All 5 rows should be committed";
// Total elements: 2 + 3 + 3 + 2 + 4 = 14
ASSERT_EQ(offsets.GetTotalElementCount(), 14);
// Verify ElementIDToRowID mapping
// Row 0: elem 0-1, Row 1: elem 2-4, Row 2: elem 5-7, Row 3: elem 8-9, Row 4: elem 10-13
std::vector<std::pair<int32_t, int32_t>> expected = {
{0, 0},
{0, 1}, // row 0
{1, 0},
{1, 1},
{1, 2}, // row 1
{2, 0},
{2, 1},
{2, 2}, // row 2
{3, 0},
{3, 1}, // row 3
{4, 0},
{4, 1},
{4, 2},
{4, 3} // row 4
};
for (int32_t elem_id = 0; elem_id < 14; ++elem_id) {
auto [row_id, elem_idx] = offsets.ElementIDToRowID(elem_id);
ASSERT_EQ(row_id, expected[elem_id].first)
<< "elem_id " << elem_id << " should map to row "
<< expected[elem_id].first;
ASSERT_EQ(elem_idx, expected[elem_id].second)
<< "elem_id " << elem_id << " should have elem_idx "
<< expected[elem_id].second;
}
}
TEST(ArrayOffsetsGrowing, ElementIDRangeOfRow) {
ArrayOffsetsGrowing offsets;
// Insert 4 rows with varying element counts
std::vector<int32_t> lens = {3, 0, 2, 5}; // includes empty array
offsets.Insert(0, lens.data(), 4);
ASSERT_EQ(offsets.GetRowCount(), 4);
ASSERT_EQ(offsets.GetTotalElementCount(), 10); // 3 + 0 + 2 + 5
// Verify ElementIDRangeOfRow
auto [start0, end0] = offsets.ElementIDRangeOfRow(0);
ASSERT_EQ(start0, 0);
ASSERT_EQ(end0, 3);
auto [start1, end1] = offsets.ElementIDRangeOfRow(1);
ASSERT_EQ(start1, 3);
ASSERT_EQ(end1, 3); // empty array
auto [start2, end2] = offsets.ElementIDRangeOfRow(2);
ASSERT_EQ(start2, 3);
ASSERT_EQ(end2, 5);
auto [start3, end3] = offsets.ElementIDRangeOfRow(3);
ASSERT_EQ(start3, 5);
ASSERT_EQ(end3, 10);
// Boundary: row_id == row_count returns (total, total)
auto [start4, end4] = offsets.ElementIDRangeOfRow(4);
ASSERT_EQ(start4, 10);
ASSERT_EQ(end4, 10);
}
TEST(ArrayOffsetsGrowing, MultiplePendingBatches) {
// Test multiple pending batches being drained in order
ArrayOffsetsGrowing offsets;
// Insert row 5 first
std::vector<int32_t> lens5 = {2};
offsets.Insert(5, lens5.data(), 1);
ASSERT_EQ(offsets.GetRowCount(), 0);
// Insert row 3
std::vector<int32_t> lens3 = {3};
offsets.Insert(3, lens3.data(), 1);
ASSERT_EQ(offsets.GetRowCount(), 0);
// Insert row 1
std::vector<int32_t> lens1 = {1};
offsets.Insert(1, lens1.data(), 1);
ASSERT_EQ(offsets.GetRowCount(), 0);
// Insert row 0 - should drain row 1, but not 3 or 5 (gap at 2)
std::vector<int32_t> lens0 = {2};
offsets.Insert(0, lens0.data(), 1);
ASSERT_EQ(offsets.GetRowCount(), 2) << "Should commit rows 0-1";
ASSERT_EQ(offsets.GetTotalElementCount(), 3); // 2 + 1
// Insert row 2 - should drain row 3, but not 5 (gap at 4)
std::vector<int32_t> lens2 = {1};
offsets.Insert(2, lens2.data(), 1);
ASSERT_EQ(offsets.GetRowCount(), 4) << "Should commit rows 0-3";
ASSERT_EQ(offsets.GetTotalElementCount(), 7); // 2 + 1 + 1 + 3
// Insert row 4 - should drain row 5
std::vector<int32_t> lens4 = {2};
offsets.Insert(4, lens4.data(), 1);
ASSERT_EQ(offsets.GetRowCount(), 6) << "Should commit rows 0-5";
ASSERT_EQ(offsets.GetTotalElementCount(), 11); // 2 + 1 + 1 + 3 + 2 + 2
// Verify final mapping
// Row 0: elem 0-1, Row 1: elem 2, Row 2: elem 3, Row 3: elem 4-6, Row 4: elem 7-8, Row 5: elem 9-10
auto [r0, i0] = offsets.ElementIDToRowID(0);
ASSERT_EQ(r0, 0);
ASSERT_EQ(i0, 0);
auto [r2, i2] = offsets.ElementIDToRowID(2);
ASSERT_EQ(r2, 1);
ASSERT_EQ(i2, 0);
auto [r4, i4] = offsets.ElementIDToRowID(4);
ASSERT_EQ(r4, 3);
ASSERT_EQ(i4, 0);
auto [r7, i7] = offsets.ElementIDToRowID(7);
ASSERT_EQ(r7, 4);
ASSERT_EQ(i7, 0);
auto [r10, i10] = offsets.ElementIDToRowID(10);
ASSERT_EQ(r10, 5);
ASSERT_EQ(i10, 1);
}
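
The tests above pin down the out-of-order behavior of ArrayOffsetsGrowing: a batch whose start row is ahead of the committed row count is cached, and is drained only once every row below it has been inserted. A sketch of that insert/drain loop under assumed member names (illustration only, not the actual implementation):

```
#include <cstdint>
#include <map>
#include <vector>

class GrowingOffsetsSketch {
 public:
    void Insert(int64_t row_offset, const int32_t* lengths, int64_t count) {
        pending_[row_offset].assign(lengths, lengths + count);
        // Drain every pending batch that is now contiguous with committed rows.
        for (auto it = pending_.begin();
             it != pending_.end() && it->first == committed_rows_;
             it = pending_.erase(it)) {
            for (int32_t len : it->second) {
                element_begin_.push_back(total_elements_);
                total_elements_ += len;
                ++committed_rows_;
            }
        }
    }

    int64_t GetRowCount() const { return committed_rows_; }
    int64_t GetTotalElementCount() const { return total_elements_; }

 private:
    std::map<int64_t, std::vector<int32_t>> pending_;  // start row -> lengths
    std::vector<int64_t> element_begin_;               // first element id per row
    int64_t committed_rows_ = 0;
    int64_t total_elements_ = 0;
};
```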

View File

@ -191,7 +191,7 @@ TEST(Indexing, BinaryBruteForce) {
SearchResult sr;
sr.total_nq_ = num_queries;
sr.unity_topK_ = topk;
sr.seg_offsets_ = std::move(sub_result.mutable_seg_offsets());
sr.seg_offsets_ = std::move(sub_result.mutable_offsets());
sr.distances_ = std::move(sub_result.mutable_distances());
auto json = SearchResultToJson(sr);

View File

@ -1579,7 +1579,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
for (auto q = 0; q < num_queries; q++) {
for (auto k = 0; k < topk; k++) {
auto offset = q * topk + k;
auto seg_offset = sub_result.get_seg_offsets()[offset];
auto seg_offset = sub_result.get_offsets()[offset];
ASSERT_EQ(std::get<std::string>(sr->primary_keys_[offset]),
str_col[seg_offset]);
ASSERT_EQ(retrieved_str_col[offset], str_col[seg_offset]);

View File

@ -1151,7 +1151,10 @@ CreatePlaceholderGroup(int64_t num_queries,
template <class TraitType = milvus::FloatVector>
auto
CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
CreatePlaceholderGroup(int64_t num_queries,
int dim,
int64_t seed = 42,
bool element_level = false) {
if (std::is_same_v<TraitType, milvus::BinaryVector>) {
assert(dim % 8 == 0);
}
@ -1162,6 +1165,7 @@ CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(TraitType::placeholder_type);
value->set_element_level(element_level);
// TODO caiyd: need update for Int8Vector
std::normal_distribution<double> dis(0, 1);
std::default_random_engine e(seed);

View File

@ -3,53 +3,55 @@ grammar Plan;
expr:
Identifier (op1=(ADD | SUB) INTERVAL interval_string=StringLiteral)? op2=(LT | LE | GT | GE | EQ | NE) ISO compare_string=StringLiteral # TimestamptzCompareForward
| ISO compare_string=StringLiteral op2=(LT | LE | GT | GE | EQ | NE) Identifier (op1=(ADD | SUB) INTERVAL interval_string=StringLiteral)? # TimestamptzCompareReverse
| IntegerConstant # Integer
| FloatingConstant # Floating
| BooleanConstant # Boolean
| StringLiteral # String
| (Identifier|Meta) # Identifier
| JSONIdentifier # JSONIdentifier
| LBRACE Identifier RBRACE # TemplateVariable
| '(' expr ')' # Parens
| '[' expr (',' expr)* ','? ']' # Array
| EmptyArray # EmptyArray
| EXISTS expr # Exists
| expr LIKE StringLiteral # Like
| TEXTMATCH'('Identifier',' StringLiteral (',' textMatchOption)? ')' # TextMatch
| PHRASEMATCH'('Identifier',' StringLiteral (',' expr)? ')' # PhraseMatch
| RANDOMSAMPLE'(' expr ')' # RandomSample
| expr POW expr # Power
| op = (ADD | SUB | BNOT | NOT) expr # Unary
// | '(' typeName ')' expr # Cast
| expr op = (MUL | DIV | MOD) expr # MulDivMod
| expr op = (ADD | SUB) expr # AddSub
| expr op = (SHL | SHR) expr # Shift
| expr op = NOT? IN expr # Term
| (JSONContains | ArrayContains)'('expr',' expr')' # JSONContains
| (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll
| (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny
| STEuqals'('Identifier','StringLiteral')' # STEuqals
| STTouches'('Identifier','StringLiteral')' # STTouches
| STOverlaps'('Identifier','StringLiteral')' # STOverlaps
| STCrosses'('Identifier','StringLiteral')' # STCrosses
| STContains'('Identifier','StringLiteral')' # STContains
| STIntersects'('Identifier','StringLiteral')' # STIntersects
| STWithin'('Identifier','StringLiteral')' # STWithin
| STDWithin'('Identifier','StringLiteral',' expr')' # STDWithin
| STIsValid'('Identifier')' # STIsValid
| ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength
| Identifier '(' ( expr (',' expr )* ','? )? ')' # Call
| expr op1 = (LT | LE) (Identifier | JSONIdentifier) op2 = (LT | LE) expr # Range
| expr op1 = (GT | GE) (Identifier | JSONIdentifier) op2 = (GT | GE) expr # ReverseRange
| expr op = (LT | LE | GT | GE) expr # Relational
| expr op = (EQ | NE) expr # Equality
| expr BAND expr # BitAnd
| expr BXOR expr # BitXor
| expr BOR expr # BitOr
| expr AND expr # LogicalAnd
| expr OR expr # LogicalOr
| (Identifier | JSONIdentifier) ISNULL # IsNull
| (Identifier | JSONIdentifier) ISNOTNULL # IsNotNull;
| IntegerConstant # Integer
| FloatingConstant # Floating
| BooleanConstant # Boolean
| StringLiteral # String
| (Identifier|Meta) # Identifier
| JSONIdentifier # JSONIdentifier
| StructSubFieldIdentifier # StructSubField
| LBRACE Identifier RBRACE # TemplateVariable
| '(' expr ')' # Parens
| '[' expr (',' expr)* ','? ']' # Array
| EmptyArray # EmptyArray
| EXISTS expr # Exists
| expr LIKE StringLiteral # Like
| TEXTMATCH'('Identifier',' StringLiteral (',' textMatchOption)? ')' # TextMatch
| PHRASEMATCH'('Identifier',' StringLiteral (',' expr)? ')' # PhraseMatch
| RANDOMSAMPLE'(' expr ')' # RandomSample
| ElementFilter'('Identifier',' expr')' # ElementFilter
| expr POW expr # Power
| op = (ADD | SUB | BNOT | NOT) expr # Unary
// | '(' typeName ')' expr # Cast
| expr op = (MUL | DIV | MOD) expr # MulDivMod
| expr op = (ADD | SUB) expr # AddSub
| expr op = (SHL | SHR) expr # Shift
| expr op = NOT? IN expr # Term
| (JSONContains | ArrayContains)'('expr',' expr')' # JSONContains
| (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll
| (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny
| STEuqals'('Identifier','StringLiteral')' # STEuqals
| STTouches'('Identifier','StringLiteral')' # STTouches
| STOverlaps'('Identifier','StringLiteral')' # STOverlaps
| STCrosses'('Identifier','StringLiteral')' # STCrosses
| STContains'('Identifier','StringLiteral')' # STContains
| STIntersects'('Identifier','StringLiteral')' # STIntersects
| STWithin'('Identifier','StringLiteral')' # STWithin
| STDWithin'('Identifier','StringLiteral',' expr')' # STDWithin
| STIsValid'('Identifier')' # STIsValid
| ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength
| Identifier '(' ( expr (',' expr )* ','? )? ')' # Call
| expr op1 = (LT | LE) (Identifier | JSONIdentifier | StructSubFieldIdentifier) op2 = (LT | LE) expr # Range
| expr op1 = (GT | GE) (Identifier | JSONIdentifier | StructSubFieldIdentifier) op2 = (GT | GE) expr # ReverseRange
| expr op = (LT | LE | GT | GE) expr # Relational
| expr op = (EQ | NE) expr # Equality
| expr BAND expr # BitAnd
| expr BXOR expr # BitXor
| expr BOR expr # BitOr
| expr AND expr # LogicalAnd
| expr OR expr # LogicalOr
| (Identifier | JSONIdentifier) ISNULL # IsNull
| (Identifier | JSONIdentifier) ISNOTNULL # IsNotNull;
textMatchOption:
MINIMUM_SHOULD_MATCH ASSIGN IntegerConstant;
@ -115,6 +117,7 @@ ArrayContains: 'array_contains' | 'ARRAY_CONTAINS';
ArrayContainsAll: 'array_contains_all' | 'ARRAY_CONTAINS_ALL';
ArrayContainsAny: 'array_contains_any' | 'ARRAY_CONTAINS_ANY';
ArrayLength: 'array_length' | 'ARRAY_LENGTH';
ElementFilter: 'element_filter' | 'ELEMENT_FILTER';
STEuqals:'st_equals' | 'ST_EQUALS';
STTouches:'st_touches' | 'ST_TOUCHES';
@ -143,6 +146,7 @@ Meta: '$meta';
StringLiteral: EncodingPrefix? ('"' DoubleSCharSequence? '"' | '\'' SingleSCharSequence? '\'');
JSONIdentifier: (Identifier | Meta)('[' (StringLiteral | DecimalConstant) ']')+;
StructSubFieldIdentifier: '$[' Identifier ']';
fragment EncodingPrefix: 'u8' | 'u' | 'U' | 'L';
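
Since StructSubFieldIdentifier is now accepted by the Range and ReverseRange alternatives as well, range comparisons over $[sub_field] parse inside an element_filter call. A hypothetical expression for illustration (the field names are made up):

```
// Hypothetical filter string; "structA", "price" and "label" are made-up names.
const char* kFilter =
    "element_filter(structA, 10 < $[price] < 100 && $[label] == 'book')";
```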

View File

@ -40,6 +40,14 @@ func FillExpressionValue(expr *planpb.Expr, templateValues map[string]*planpb.Ge
return FillJSONContainsExpressionValue(e.JsonContainsExpr, templateValues)
case *planpb.Expr_RandomSampleExpr:
return FillExpressionValue(expr.GetExpr().(*planpb.Expr_RandomSampleExpr).RandomSampleExpr.GetPredicate(), templateValues)
case *planpb.Expr_ElementFilterExpr:
if err := FillExpressionValue(e.ElementFilterExpr.GetElementExpr(), templateValues); err != nil {
return err
}
if e.ElementFilterExpr.GetPredicate() != nil {
return FillExpressionValue(e.ElementFilterExpr.GetPredicate(), templateValues)
}
return nil
default:
return fmt.Errorf("this expression no need to fill placeholder with expr type: %T", e)
}

File diff suppressed because one or more lines are too long

View File

@ -46,24 +46,26 @@ ArrayContains=45
ArrayContainsAll=46
ArrayContainsAny=47
ArrayLength=48
STEuqals=49
STTouches=50
STOverlaps=51
STCrosses=52
STContains=53
STIntersects=54
STWithin=55
STDWithin=56
STIsValid=57
BooleanConstant=58
IntegerConstant=59
FloatingConstant=60
Identifier=61
Meta=62
StringLiteral=63
JSONIdentifier=64
Whitespace=65
Newline=66
ElementFilter=49
STEuqals=50
STTouches=51
STOverlaps=52
STCrosses=53
STContains=54
STIntersects=55
STWithin=56
STDWithin=57
STIsValid=58
BooleanConstant=59
IntegerConstant=60
FloatingConstant=61
Identifier=62
Meta=63
StringLiteral=64
JSONIdentifier=65
StructSubFieldIdentifier=66
Whitespace=67
Newline=68
'('=1
')'=2
'['=3
@ -90,4 +92,4 @@ Newline=66
'|'=32
'^'=33
'~'=38
'$meta'=62
'$meta'=63

File diff suppressed because one or more lines are too long

View File

@ -46,24 +46,26 @@ ArrayContains=45
ArrayContainsAll=46
ArrayContainsAny=47
ArrayLength=48
STEuqals=49
STTouches=50
STOverlaps=51
STCrosses=52
STContains=53
STIntersects=54
STWithin=55
STDWithin=56
STIsValid=57
BooleanConstant=58
IntegerConstant=59
FloatingConstant=60
Identifier=61
Meta=62
StringLiteral=63
JSONIdentifier=64
Whitespace=65
Newline=66
ElementFilter=49
STEuqals=50
STTouches=51
STOverlaps=52
STCrosses=53
STContains=54
STIntersects=55
STWithin=56
STDWithin=57
STIsValid=58
BooleanConstant=59
IntegerConstant=60
FloatingConstant=61
Identifier=62
Meta=63
StringLiteral=64
JSONIdentifier=65
StructSubFieldIdentifier=66
Whitespace=67
Newline=68
'('=1
')'=2
'['=3
@ -90,4 +92,4 @@ Newline=66
'|'=32
'^'=33
'~'=38
'$meta'=62
'$meta'=63

View File

@ -7,18 +7,6 @@ type BasePlanVisitor struct {
*antlr.BaseParseTreeVisitor
}
func (v *BasePlanVisitor) VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitRandomSample(ctx *RandomSampleContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitParens(ctx *ParensContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitString(ctx *StringContext) interface{} {
return v.VisitChildren(ctx)
}
@ -27,22 +15,10 @@ func (v *BasePlanVisitor) VisitFloating(ctx *FloatingContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitLogicalOr(ctx *LogicalOrContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitIsNotNull(ctx *IsNotNullContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMulDivMod(ctx *MulDivModContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitIdentifier(ctx *IdentifierContext) interface{} {
return v.VisitChildren(ctx)
}
@ -55,14 +31,6 @@ func (v *BasePlanVisitor) VisitLike(ctx *LikeContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitLogicalAnd(ctx *LogicalAndContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTemplateVariable(ctx *TemplateVariableContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitEquality(ctx *EqualityContext) interface{} {
return v.VisitChildren(ctx)
}
@ -71,14 +39,6 @@ func (v *BasePlanVisitor) VisitBoolean(ctx *BooleanContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTDWithin(ctx *STDWithinContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitShift(ctx *ShiftContext) interface{} {
return v.VisitChildren(ctx)
}
@ -87,54 +47,26 @@ func (v *BasePlanVisitor) VisitTimestamptzCompareForward(ctx *TimestamptzCompare
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitCall(ctx *CallContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTCrosses(ctx *STCrossesContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitReverseRange(ctx *ReverseRangeContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitOr(ctx *BitOrContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitEmptyArray(ctx *EmptyArrayContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitAddSub(ctx *AddSubContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitPhraseMatch(ctx *PhraseMatchContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitRelational(ctx *RelationalContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitArrayLength(ctx *ArrayLengthContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTextMatch(ctx *TextMatchContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTTouches(ctx *STTouchesContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTContains(ctx *STContainsContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTerm(ctx *TermContext) interface{} {
return v.VisitChildren(ctx)
}
@ -151,6 +83,94 @@ func (v *BasePlanVisitor) VisitRange(ctx *RangeContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTIsValid(ctx *STIsValidContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitXor(ctx *BitXorContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitElementFilter(ctx *ElementFilterContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitAnd(ctx *BitAndContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTOverlaps(ctx *STOverlapsContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitRandomSample(ctx *RandomSampleContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitParens(ctx *ParensContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitLogicalOr(ctx *LogicalOrContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMulDivMod(ctx *MulDivModContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitLogicalAnd(ctx *LogicalAndContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTemplateVariable(ctx *TemplateVariableContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTDWithin(ctx *STDWithinContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitCall(ctx *CallContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTCrosses(ctx *STCrossesContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitOr(ctx *BitOrContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitAddSub(ctx *AddSubContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitRelational(ctx *RelationalContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTextMatch(ctx *TextMatchContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTContains(ctx *STContainsContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitUnary(ctx *UnaryContext) interface{} {
return v.VisitChildren(ctx)
}
@ -167,22 +187,10 @@ func (v *BasePlanVisitor) VisitJSONContainsAny(ctx *JSONContainsAnyContext) inte
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTIsValid(ctx *STIsValidContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitXor(ctx *BitXorContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitExists(ctx *ExistsContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitAnd(ctx *BitAndContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTEuqals(ctx *STEuqalsContext) interface{} {
return v.VisitChildren(ctx)
}
@ -191,11 +199,11 @@ func (v *BasePlanVisitor) VisitIsNull(ctx *IsNullContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitPower(ctx *PowerContext) interface{} {
func (v *BasePlanVisitor) VisitStructSubField(ctx *StructSubFieldContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTOverlaps(ctx *STOverlapsContext) interface{} {
func (v *BasePlanVisitor) VisitPower(ctx *PowerContext) interface{} {
return v.VisitChildren(ctx)
}

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -7,33 +7,15 @@ import "github.com/antlr4-go/antlr/v4"
type PlanVisitor interface {
antlr.ParseTreeVisitor
// Visit a parse tree produced by PlanParser#JSONIdentifier.
VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{}
// Visit a parse tree produced by PlanParser#RandomSample.
VisitRandomSample(ctx *RandomSampleContext) interface{}
// Visit a parse tree produced by PlanParser#Parens.
VisitParens(ctx *ParensContext) interface{}
// Visit a parse tree produced by PlanParser#String.
VisitString(ctx *StringContext) interface{}
// Visit a parse tree produced by PlanParser#Floating.
VisitFloating(ctx *FloatingContext) interface{}
// Visit a parse tree produced by PlanParser#JSONContainsAll.
VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{}
// Visit a parse tree produced by PlanParser#LogicalOr.
VisitLogicalOr(ctx *LogicalOrContext) interface{}
// Visit a parse tree produced by PlanParser#IsNotNull.
VisitIsNotNull(ctx *IsNotNullContext) interface{}
// Visit a parse tree produced by PlanParser#MulDivMod.
VisitMulDivMod(ctx *MulDivModContext) interface{}
// Visit a parse tree produced by PlanParser#Identifier.
VisitIdentifier(ctx *IdentifierContext) interface{}
@ -43,66 +25,33 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#Like.
VisitLike(ctx *LikeContext) interface{}
// Visit a parse tree produced by PlanParser#LogicalAnd.
VisitLogicalAnd(ctx *LogicalAndContext) interface{}
// Visit a parse tree produced by PlanParser#TemplateVariable.
VisitTemplateVariable(ctx *TemplateVariableContext) interface{}
// Visit a parse tree produced by PlanParser#Equality.
VisitEquality(ctx *EqualityContext) interface{}
// Visit a parse tree produced by PlanParser#Boolean.
VisitBoolean(ctx *BooleanContext) interface{}
// Visit a parse tree produced by PlanParser#TimestamptzCompareReverse.
VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{}
// Visit a parse tree produced by PlanParser#STDWithin.
VisitSTDWithin(ctx *STDWithinContext) interface{}
// Visit a parse tree produced by PlanParser#Shift.
VisitShift(ctx *ShiftContext) interface{}
// Visit a parse tree produced by PlanParser#TimestamptzCompareForward.
VisitTimestamptzCompareForward(ctx *TimestamptzCompareForwardContext) interface{}
// Visit a parse tree produced by PlanParser#Call.
VisitCall(ctx *CallContext) interface{}
// Visit a parse tree produced by PlanParser#STCrosses.
VisitSTCrosses(ctx *STCrossesContext) interface{}
// Visit a parse tree produced by PlanParser#ReverseRange.
VisitReverseRange(ctx *ReverseRangeContext) interface{}
// Visit a parse tree produced by PlanParser#BitOr.
VisitBitOr(ctx *BitOrContext) interface{}
// Visit a parse tree produced by PlanParser#EmptyArray.
VisitEmptyArray(ctx *EmptyArrayContext) interface{}
// Visit a parse tree produced by PlanParser#AddSub.
VisitAddSub(ctx *AddSubContext) interface{}
// Visit a parse tree produced by PlanParser#PhraseMatch.
VisitPhraseMatch(ctx *PhraseMatchContext) interface{}
// Visit a parse tree produced by PlanParser#Relational.
VisitRelational(ctx *RelationalContext) interface{}
// Visit a parse tree produced by PlanParser#ArrayLength.
VisitArrayLength(ctx *ArrayLengthContext) interface{}
// Visit a parse tree produced by PlanParser#TextMatch.
VisitTextMatch(ctx *TextMatchContext) interface{}
// Visit a parse tree produced by PlanParser#STTouches.
VisitSTTouches(ctx *STTouchesContext) interface{}
// Visit a parse tree produced by PlanParser#STContains.
VisitSTContains(ctx *STContainsContext) interface{}
// Visit a parse tree produced by PlanParser#Term.
VisitTerm(ctx *TermContext) interface{}
@ -115,6 +64,72 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#Range.
VisitRange(ctx *RangeContext) interface{}
// Visit a parse tree produced by PlanParser#STIsValid.
VisitSTIsValid(ctx *STIsValidContext) interface{}
// Visit a parse tree produced by PlanParser#BitXor.
VisitBitXor(ctx *BitXorContext) interface{}
// Visit a parse tree produced by PlanParser#ElementFilter.
VisitElementFilter(ctx *ElementFilterContext) interface{}
// Visit a parse tree produced by PlanParser#BitAnd.
VisitBitAnd(ctx *BitAndContext) interface{}
// Visit a parse tree produced by PlanParser#STOverlaps.
VisitSTOverlaps(ctx *STOverlapsContext) interface{}
// Visit a parse tree produced by PlanParser#JSONIdentifier.
VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{}
// Visit a parse tree produced by PlanParser#RandomSample.
VisitRandomSample(ctx *RandomSampleContext) interface{}
// Visit a parse tree produced by PlanParser#Parens.
VisitParens(ctx *ParensContext) interface{}
// Visit a parse tree produced by PlanParser#JSONContainsAll.
VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{}
// Visit a parse tree produced by PlanParser#LogicalOr.
VisitLogicalOr(ctx *LogicalOrContext) interface{}
// Visit a parse tree produced by PlanParser#MulDivMod.
VisitMulDivMod(ctx *MulDivModContext) interface{}
// Visit a parse tree produced by PlanParser#LogicalAnd.
VisitLogicalAnd(ctx *LogicalAndContext) interface{}
// Visit a parse tree produced by PlanParser#TemplateVariable.
VisitTemplateVariable(ctx *TemplateVariableContext) interface{}
// Visit a parse tree produced by PlanParser#TimestamptzCompareReverse.
VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{}
// Visit a parse tree produced by PlanParser#STDWithin.
VisitSTDWithin(ctx *STDWithinContext) interface{}
// Visit a parse tree produced by PlanParser#Call.
VisitCall(ctx *CallContext) interface{}
// Visit a parse tree produced by PlanParser#STCrosses.
VisitSTCrosses(ctx *STCrossesContext) interface{}
// Visit a parse tree produced by PlanParser#BitOr.
VisitBitOr(ctx *BitOrContext) interface{}
// Visit a parse tree produced by PlanParser#AddSub.
VisitAddSub(ctx *AddSubContext) interface{}
// Visit a parse tree produced by PlanParser#Relational.
VisitRelational(ctx *RelationalContext) interface{}
// Visit a parse tree produced by PlanParser#TextMatch.
VisitTextMatch(ctx *TextMatchContext) interface{}
// Visit a parse tree produced by PlanParser#STContains.
VisitSTContains(ctx *STContainsContext) interface{}
// Visit a parse tree produced by PlanParser#Unary.
VisitUnary(ctx *UnaryContext) interface{}
@ -127,30 +142,21 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#JSONContainsAny.
VisitJSONContainsAny(ctx *JSONContainsAnyContext) interface{}
// Visit a parse tree produced by PlanParser#STIsValid.
VisitSTIsValid(ctx *STIsValidContext) interface{}
// Visit a parse tree produced by PlanParser#BitXor.
VisitBitXor(ctx *BitXorContext) interface{}
// Visit a parse tree produced by PlanParser#Exists.
VisitExists(ctx *ExistsContext) interface{}
// Visit a parse tree produced by PlanParser#BitAnd.
VisitBitAnd(ctx *BitAndContext) interface{}
// Visit a parse tree produced by PlanParser#STEuqals.
VisitSTEuqals(ctx *STEuqalsContext) interface{}
// Visit a parse tree produced by PlanParser#IsNull.
VisitIsNull(ctx *IsNullContext) interface{}
// Visit a parse tree produced by PlanParser#StructSubField.
VisitStructSubField(ctx *StructSubFieldContext) interface{}
// Visit a parse tree produced by PlanParser#Power.
VisitPower(ctx *PowerContext) interface{}
// Visit a parse tree produced by PlanParser#STOverlaps.
VisitSTOverlaps(ctx *STOverlapsContext) interface{}
// Visit a parse tree produced by PlanParser#textMatchOption.
VisitTextMatchOption(ctx *TextMatchOptionContext) interface{}
}

View File

@ -42,6 +42,8 @@ type ParserVisitor struct {
parser.BasePlanVisitor
schema *typeutil.SchemaHelper
args *ParserVisitorArgs
// currentStructArrayField stores the struct array field name when processing ElementFilter
currentStructArrayField string
}
func NewParserVisitor(schema *typeutil.SchemaHelper, args *ParserVisitorArgs) *ParserVisitor {
@ -658,6 +660,10 @@ func isRandomSampleExpr(expr *ExprWithType) bool {
return expr.expr.GetRandomSampleExpr() != nil
}
func isElementFilterExpr(expr *ExprWithType) bool {
return expr.expr.GetElementFilterExpr() != nil
}
const EPSILON = 1e-10
func (v *ParserVisitor) VisitRandomSample(ctx *parser.RandomSampleContext) interface{} {
@ -773,7 +779,47 @@ func (v *ParserVisitor) VisitTerm(ctx *parser.TermContext) interface{} {
}
}
func (v *ParserVisitor) getChildColumnInfo(identifier, child antlr.TerminalNode) (*planpb.ColumnInfo, error) {
func isValidStructSubField(tokenText string) bool {
return len(tokenText) >= 4 && tokenText[:2] == "$[" && tokenText[len(tokenText)-1] == ']'
}
func (v *ParserVisitor) getColumnInfoFromStructSubField(tokenText string) (*planpb.ColumnInfo, error) {
if !isValidStructSubField(tokenText) {
return nil, fmt.Errorf("invalid struct sub-field syntax: %s", tokenText)
}
// Remove "$[" prefix and "]" suffix
fieldName := tokenText[2 : len(tokenText)-1]
// Check if we're inside an ElementFilter context
if v.currentStructArrayField == "" {
return nil, fmt.Errorf("$[%s] syntax can only be used inside ElementFilter", fieldName)
}
// Construct full field name for struct array field
fullFieldName := v.currentStructArrayField + "[" + fieldName + "]"
// Get the struct array field info
field, err := v.schema.GetFieldFromName(fullFieldName)
if err != nil {
return nil, fmt.Errorf("array field not found: %s, error: %s", fullFieldName, err)
}
// In element-level context, data_type should be the element type
elementType := field.GetElementType()
return &planpb.ColumnInfo{
FieldId: field.FieldID,
DataType: elementType, // Use element type, not storage type
IsPrimaryKey: field.IsPrimaryKey,
IsAutoID: field.AutoID,
IsPartitionKey: field.IsPartitionKey,
IsClusteringKey: field.IsClusteringKey,
ElementType: elementType,
Nullable: field.GetNullable(),
IsElementLevel: true, // Mark as element-level access
}, nil
}
func (v *ParserVisitor) getChildColumnInfo(identifier, child, structSubField antlr.TerminalNode) (*planpb.ColumnInfo, error) {
if identifier != nil {
childExpr, err := v.translateIdentifier(identifier.GetText())
if err != nil {
@ -782,6 +828,10 @@ func (v *ParserVisitor) getChildColumnInfo(identifier, child antlr.TerminalNode)
return toColumnInfo(childExpr), nil
}
if structSubField != nil {
return v.getColumnInfoFromStructSubField(structSubField.GetText())
}
return v.getColumnInfoFromJSONIdentifier(child.GetText())
}
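As the test schema later in this diff shows, struct sub-fields are registered under names like `struct_array[sub_int]`; the helper above rebuilds exactly that name from the surrounding element_filter context and the `$[...]` token. A small standalone sketch (field names borrowed from the test schema below):

```
package main

import (
	"fmt"
	"strings"
)

// subFieldFullName mirrors getColumnInfoFromStructSubField's name resolution:
// "$[sub_int]" used inside element_filter(struct_array, ...) resolves to the
// schema-level field name "struct_array[sub_int]".
func subFieldFullName(structArrayField, token string) (string, error) {
	if len(token) < 4 || !strings.HasPrefix(token, "$[") || !strings.HasSuffix(token, "]") {
		return "", fmt.Errorf("invalid struct sub-field syntax: %s", token)
	}
	return structArrayField + "[" + token[2:len(token)-1] + "]", nil
}

func main() {
	name, err := subFieldFullName("struct_array", "$[sub_int]")
	fmt.Println(name, err) // struct_array[sub_int] <nil>
}
```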
@ -812,7 +862,7 @@ func (v *ParserVisitor) VisitCall(ctx *parser.CallContext) interface{} {
// VisitRange translates expr to range plan.
func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} {
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), ctx.StructSubFieldIdentifier())
if err != nil {
return err
}
@ -893,7 +943,7 @@ func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} {
// VisitReverseRange parses the expression like "1 > a > 0".
func (v *ParserVisitor) VisitReverseRange(ctx *parser.ReverseRangeContext) interface{} {
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), ctx.StructSubFieldIdentifier())
if err != nil {
return err
}
@ -1081,6 +1131,10 @@ func (v *ParserVisitor) VisitLogicalOr(ctx *parser.LogicalOrContext) interface{}
return errors.New("random sample expression cannot be used in logical and expression")
}
if isElementFilterExpr(leftExpr) {
return errors.New("element filter expression can only be the last expression in the logical or expression")
}
if !canBeExecuted(leftExpr) || !canBeExecuted(rightExpr) {
return errors.New("'or' can only be used between boolean expressions")
}
@ -1133,6 +1187,10 @@ func (v *ParserVisitor) VisitLogicalAnd(ctx *parser.LogicalAndContext) interface
return errors.New("random sample expression can only be the last expression in the logical and expression")
}
if isElementFilterExpr(leftExpr) {
return errors.New("element filter expression can only be the last expression in the logical and expression")
}
if !canBeExecuted(leftExpr) || !canBeExecuted(rightExpr) {
return errors.New("'and' can only be used between boolean expressions")
}
@ -1146,6 +1204,15 @@ func (v *ParserVisitor) VisitLogicalAnd(ctx *parser.LogicalAndContext) interface
RandomSampleExpr: randomSampleExpr,
},
}
} else if isElementFilterExpr(rightExpr) {
// Similar to RandomSampleExpr, extract doc-level predicate
elementFilterExpr := rightExpr.expr.GetElementFilterExpr()
elementFilterExpr.Predicate = leftExpr.expr
expr = &planpb.Expr{
Expr: &planpb.Expr_ElementFilterExpr{
ElementFilterExpr: elementFilterExpr,
},
}
} else {
expr = &planpb.Expr{
Expr: &planpb.Expr_BinaryExpr{
@ -1410,7 +1477,7 @@ func (v *ParserVisitor) VisitEmptyArray(ctx *parser.EmptyArrayContext) interface
}
func (v *ParserVisitor) VisitIsNotNull(ctx *parser.IsNotNullContext) interface{} {
column, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
column, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), nil)
if err != nil {
return err
}
@ -1450,7 +1517,7 @@ func (v *ParserVisitor) VisitIsNotNull(ctx *parser.IsNotNullContext) interface{}
}
func (v *ParserVisitor) VisitIsNull(ctx *parser.IsNullContext) interface{} {
column, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
column, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), nil)
if err != nil {
return err
}
@ -1653,7 +1720,7 @@ func (v *ParserVisitor) VisitJSONContainsAny(ctx *parser.JSONContainsAnyContext)
}
func (v *ParserVisitor) VisitArrayLength(ctx *parser.ArrayLengthContext) interface{} {
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier())
columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier(), nil)
if err != nil {
return err
}
@ -2182,3 +2249,90 @@ func validateAndExtractMinShouldMatch(minShouldMatchExpr interface{}) ([]*planpb
}
return nil, nil
}
// VisitElementFilter handles ElementFilter(structArrayField, elementExpr) syntax.
func (v *ParserVisitor) VisitElementFilter(ctx *parser.ElementFilterContext) interface{} {
// Check for nested ElementFilter - not allowed
if v.currentStructArrayField != "" {
return fmt.Errorf("nested ElementFilter is not supported, already inside ElementFilter for field: %s", v.currentStructArrayField)
}
// Get struct array field name (first parameter)
arrayFieldName := ctx.Identifier().GetText()
// Set current context for element expression parsing
v.currentStructArrayField = arrayFieldName
defer func() { v.currentStructArrayField = "" }()
elementExpr := ctx.Expr().Accept(v)
if err := getError(elementExpr); err != nil {
return fmt.Errorf("cannot parse element expression: %s, error: %s", ctx.Expr().GetText(), err)
}
exprWithType := getExpr(elementExpr)
if exprWithType == nil {
return fmt.Errorf("invalid element expression: %s", ctx.Expr().GetText())
}
// Build ElementFilterExpr proto
return &ExprWithType{
expr: &planpb.Expr{
Expr: &planpb.Expr_ElementFilterExpr{
ElementFilterExpr: &planpb.ElementFilterExpr{
ElementExpr: exprWithType.expr,
StructName: arrayFieldName,
},
},
},
dataType: schemapb.DataType_Bool,
}
}
// VisitStructSubField handles $[fieldName] syntax within ElementFilter.
func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) interface{} {
// Extract the field name from $[fieldName]
tokenText := ctx.StructSubFieldIdentifier().GetText()
if !isValidStructSubField(tokenText) {
return fmt.Errorf("invalid struct sub-field syntax: %s", tokenText)
}
// Remove "$[" prefix and "]" suffix
fieldName := tokenText[2 : len(tokenText)-1]
// Check if we're inside an ElementFilter context
if v.currentStructArrayField == "" {
return fmt.Errorf("$[%s] syntax can only be used inside ElementFilter", fieldName)
}
// Construct full field name for struct array field
fullFieldName := v.currentStructArrayField + "[" + fieldName + "]"
// Get the struct array field info
field, err := v.schema.GetFieldFromName(fullFieldName)
if err != nil {
return fmt.Errorf("array field not found: %s, error: %s", fullFieldName, err)
}
// In element-level context, data_type should be the element type
elementType := field.GetElementType()
return &ExprWithType{
expr: &planpb.Expr{
Expr: &planpb.Expr_ColumnExpr{
ColumnExpr: &planpb.ColumnExpr{
Info: &planpb.ColumnInfo{
FieldId: field.FieldID,
DataType: elementType, // Use element type, not storage type
IsPrimaryKey: field.IsPrimaryKey,
IsAutoID: field.AutoID,
IsPartitionKey: field.IsPartitionKey,
IsClusteringKey: field.IsClusteringKey,
ElementType: elementType,
Nullable: field.GetNullable(),
IsElementLevel: true, // Mark as element-level access
},
},
},
},
dataType: elementType, // Expression evaluates to element type
nodeDependent: true,
}
}
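Taken together, a combined filter such as `Int64Field > 0 && element_filter(struct_array, $[sub_int] > 1)` (one of the valid cases in the tests below) is rewritten into a single ElementFilterExpr whose predicate carries the hoisted document-level condition. A minimal sketch with simplified stand-in types (the real messages are the generated planpb types added in the proto diff further down):

```
package main

import "fmt"

// Simplified stand-ins for the planpb messages referenced in this diff;
// field names follow ElementFilterExpr{element_expr, struct_name, predicate}.
type Expr struct{ Desc string }

type ElementFilterExpr struct {
	ElementExpr *Expr  // per-element predicate, e.g. $[sub_int] > 1
	StructName  string // the struct array field, e.g. "struct_array"
	Predicate   *Expr  // document-level predicate hoisted by VisitLogicalAnd
}

func main() {
	// "Int64Field > 0 && element_filter(struct_array, $[sub_int] > 1)"
	ef := ElementFilterExpr{
		ElementExpr: &Expr{Desc: "$[sub_int] > 1"},
		StructName:  "struct_array",
		Predicate:   &Expr{Desc: "Int64Field > 0"},
	}
	fmt.Printf("struct=%s element=%q doc=%q\n",
		ef.StructName, ef.ElementExpr.Desc, ef.Predicate.Desc)
}
```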

View File

@ -48,11 +48,27 @@ func newTestSchema(EnableDynamicField bool) *schemapb.CollectionSchema {
ElementType: schemapb.DataType_VarChar,
})
structArrayField := &schemapb.StructArrayFieldSchema{
FieldID: 132, Name: "struct_array", Fields: []*schemapb.FieldSchema{
{
FieldID: 133, Name: "struct_array[sub_str]", IsPrimaryKey: false, Description: "sub struct array field for string",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_VarChar,
},
{
FieldID: 134, Name: "struct_array[sub_int]", IsPrimaryKey: false, Description: "sub struct array field for int",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
},
},
}
return &schemapb.CollectionSchema{
Name: "test",
Description: "schema for test used",
AutoID: true,
Fields: fields,
StructArrayFields: []*schemapb.StructArrayFieldSchema{structArrayField},
EnableDynamicField: EnableDynamicField,
}
}
@ -2487,3 +2503,57 @@ func TestExpr_GISFunctionsInvalidParameterTypes(t *testing.T) {
assertInvalidExpr(t, schema, expr)
}
}
func TestExpr_ElementFilter(t *testing.T) {
schema := newTestSchema(true)
helper, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
// Valid expressions
validExprs := []string{
`element_filter(struct_array, 2 > $[sub_int] > 1)`,
`element_filter(struct_array, $[sub_int] > 1)`,
`element_filter(struct_array, $[sub_int] == 100)`,
`element_filter(struct_array, $[sub_int] >= 0)`,
`element_filter(struct_array, $[sub_int] <= 1000)`,
`element_filter(struct_array, $[sub_int] != 0)`,
`element_filter(struct_array, $[sub_str] == "1")`,
`element_filter(struct_array, $[sub_str] != "")`,
`element_filter(struct_array, $[sub_str] == "1" || $[sub_int] > 1)`,
`element_filter(struct_array, $[sub_str] == "1" && $[sub_int] > 1)`,
`element_filter(struct_array, $[sub_int] > 0 && $[sub_int] < 100)`,
`element_filter(struct_array, ($[sub_int] > 0 && $[sub_int] < 100) || $[sub_str] == "default")`,
`element_filter(struct_array, !($[sub_int] < 0))`,
`Int64Field > 0 && element_filter(struct_array, $[sub_int] > 1)`,
}
for _, expr := range validExprs {
assertValidExpr(t, helper, expr)
}
// Invalid expressions
invalidExprs := []string{
`element_filter(struct_array, element_filter(struct_array, $[sub_int] > 1))`,
`element_filter(struct_array, $[sub_int] > 1 && element_filter(struct_array, $[sub_str] == "1"))`,
`$[sub_int] > 1`,
`Int64Field > 0 && $[sub_int] > 1`,
`element_filter(struct_array, $[non_existent_field] > 1)`,
`element_filter(non_existent_array, $[sub_int] > 1)`,
`element_filter(struct_array)`, // missing element expression
`element_filter()`, // missing all parameters
`element_filter(struct_array, $[sub_int] > 1) || element_filter(struct_array, $[sub_str] == "test")`,
`element_filter(struct_array, $[sub_int] > 1) && Int64Field > 0`,
}
for _, expr := range invalidExprs {
assertInvalidExpr(t, helper, expr)
}
}

View File

@ -150,6 +150,18 @@ func reduceAdvanceGroupBy(ctx context.Context, subSearchResultData []*schemapb.S
gpFieldBuilder.Add(groupByVal)
typeutil.AppendPKs(ret.Results.Ids, pk)
ret.Results.Scores = append(ret.Results.Scores, score)
// Handle ElementIndices if present
if subData.ElementIndices != nil {
if ret.Results.ElementIndices == nil {
ret.Results.ElementIndices = &schemapb.LongArray{
Data: make([]int64, 0, limit),
}
}
elemIdx := subData.ElementIndices.GetData()[innerIdx]
ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
}
dataCount += 1
}
}
@ -308,6 +320,18 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
}
typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
// Handle ElementIndices if present
if subResData.ElementIndices != nil {
if ret.Results.ElementIndices == nil {
ret.Results.ElementIndices = &schemapb.LongArray{
Data: make([]int64, 0, limit),
}
}
elemIdx := subResData.ElementIndices.GetData()[groupEntity.resultIdx]
ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
}
gpFieldBuilder.Add(groupVal)
}
}
@ -436,6 +460,18 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
}
typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
ret.Results.Scores = append(ret.Results.Scores, score)
// Handle ElementIndices if present
if subSearchResultData[subSearchIdx].ElementIndices != nil {
if ret.Results.ElementIndices == nil {
ret.Results.ElementIndices = &schemapb.LongArray{
Data: make([]int64, 0, limit),
}
}
elemIdx := subSearchResultData[subSearchIdx].ElementIndices.GetData()[resultDataIdx]
ret.Results.ElementIndices.Data = append(ret.Results.ElementIndices.Data, elemIdx)
}
cursors[subSearchIdx]++
}
if realTopK != -1 && realTopK != j {

View File

@ -421,10 +421,22 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "int vector index does not support metric type: "+metricType)
}
} else if typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) {
// TODO(SpadeA): adjust it when more metric types are supported. Especially, when different metric types
// are supported for different element types.
if !funcutil.SliceContain(indexparamcheck.EmbListMetrics, metricType) {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector index does not support metric type: "+metricType)
if typeutil.IsDenseFloatVectorType(cit.fieldSchema.ElementType) {
if !funcutil.SliceContain(indexparamcheck.FloatVectorMetrics, metricType) {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector with float element type does not support metric type: "+metricType)
}
} else if typeutil.IsBinaryVectorType(cit.fieldSchema.ElementType) {
if !funcutil.SliceContain(indexparamcheck.BinaryVectorMetrics, metricType) {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector with binary element type does not support metric type: "+metricType)
}
} else if typeutil.IsIntVectorType(cit.fieldSchema.ElementType) {
if !funcutil.SliceContain(indexparamcheck.IntVectorMetrics, metricType) {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector with int element type does not support metric type: "+metricType)
}
} else {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector index does not support metric type: "+metricType)
}
}
}
}

View File

@ -1229,37 +1229,6 @@ func Test_checkEmbeddingListIndex(t *testing.T) {
assert.NoError(t, err)
})
t.Run("metrics wrong for embedding list index", func(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "HNSW",
},
{
Key: common.MetricTypeKey,
Value: metric.L2,
},
},
IndexName: "",
},
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "EmbListFloat",
IsPrimaryKey: false,
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
},
}
err := cit.parseIndexParams(context.TODO())
assert.True(t, strings.Contains(err.Error(), "array of vector index does not support metric type: L2"))
})
t.Run("metric type wrong", func(t *testing.T) {
cit := &createIndexTask{
Condition: nil,

View File

@ -45,6 +45,20 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
Topks: make([]int64, 0),
}
// Check element-level consistency: all results must have ElementIndices or none
hasElementIndices := searchResultData[0].ElementIndices != nil
for i, data := range searchResultData {
if (data.ElementIndices != nil) != hasElementIndices {
return nil, fmt.Errorf("inconsistent element-level flag in search results: result[0] has ElementIndices=%v, but result[%d] has ElementIndices=%v",
hasElementIndices, i, data.ElementIndices != nil)
}
}
if hasElementIndices {
ret.ElementIndices = &schemapb.LongArray{
Data: make([]int64, 0),
}
}
resultOffsets := make([][]int64, len(searchResultData))
for i := 0; i < len(searchResultData); i++ {
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
@ -76,6 +90,9 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
typeutil.AppendPKs(ret.Ids, id)
ret.Scores = append(ret.Scores, score)
if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil {
ret.ElementIndices.Data = append(ret.ElementIndices.Data, searchResultData[sel].ElementIndices.Data[idx])
}
idSet[id] = struct{}{}
j++
} else {
@ -127,6 +144,20 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
Topks: make([]int64, 0),
}
// Check element-level consistency: all results must have ElementIndices or none
hasElementIndices := searchResultData[0].ElementIndices != nil
for i, data := range searchResultData {
if (data.ElementIndices != nil) != hasElementIndices {
return nil, fmt.Errorf("inconsistent element-level flag in search results: result[0] has ElementIndices=%v, but result[%d] has ElementIndices=%v",
hasElementIndices, i, data.ElementIndices != nil)
}
}
if hasElementIndices {
ret.ElementIndices = &schemapb.LongArray{
Data: make([]int64, 0),
}
}
resultOffsets := make([][]int64, len(searchResultData))
groupByValIterator := make([]func(int) any, len(searchResultData))
for i := range searchResultData {
@ -180,6 +211,9 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
typeutil.AppendPKs(ret.Ids, id)
ret.Scores = append(ret.Scores, score)
if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil {
ret.ElementIndices.Data = append(ret.ElementIndices.Data, searchResultData[sel].ElementIndices.Data[idx])
}
gpFieldBuilder.Add(groupByVal)
groupByValueMap[groupByVal] += 1
idSet[id] = struct{}{}
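The reduce path relies on an all-or-none invariant: either every shard-level result carries ElementIndices (element-level search over an embedding list) or none does. A standalone sketch of that consistency check, using a plain slice of booleans in place of the real result structs:

```
package main

import "fmt"

// Mirrors the guard added above: the first result decides whether element
// indices are expected, and every other result must agree.
func checkElementIndicesConsistency(hasElementIndices []bool) error {
	if len(hasElementIndices) == 0 {
		return nil
	}
	want := hasElementIndices[0]
	for i, has := range hasElementIndices {
		if has != want {
			return fmt.Errorf("inconsistent element-level flag: result[0]=%v, result[%d]=%v", want, i, has)
		}
	}
	return nil
}

func main() {
	fmt.Println(checkElementIndicesConsistency([]bool{true, true, true}))  // <nil>
	fmt.Println(checkElementIndicesConsistency([]bool{true, false, true})) // error
}
```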

View File

@ -2120,8 +2120,18 @@ func estimateLoadingResourceUsageOfSegment(schema *schemapb.CollectionSchema, lo
}
}
// Per-struct-array memory size, used to keep the mapping between row ID and element ID
var structArrayOffsetsSize uint64
// PART 6: calculate size of struct array offsets
// The memory size is 4 * row_count + 4 * total_element_count
// We cannot easily get the element count, so we estimate it as row_count * 10
rowCount := uint64(loadInfo.GetNumOfRows())
for range len(schema.GetStructArrayFields()) {
structArrayOffsetsSize += 4*rowCount + 4*rowCount*10
}
return &ResourceUsage{
MemorySize: segMemoryLoadingSize + indexMemorySize,
MemorySize: segMemoryLoadingSize + indexMemorySize + structArrayOffsetsSize,
DiskSize: segDiskLoadingSize,
MmapFieldCount: mmapFieldCount,
FieldGpuMemorySize: fieldGpuMemorySize,
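A rough worked example of the estimate above, assuming a hypothetical segment with 1,000,000 rows and a single struct array field (the factor 10 is the code's heuristic for elements per row, not a measured value):

```
package main

import "fmt"

func main() {
	const rowCount uint64 = 1_000_000
	const structArrayFieldNum = 1

	var structArrayOffsetsSize uint64
	for i := 0; i < structArrayFieldNum; i++ {
		// 4 bytes per row plus 4 bytes per element, with the element count
		// estimated as rowCount * 10.
		structArrayOffsetsSize += 4*rowCount + 4*rowCount*10
	}
	fmt.Printf("estimated struct array offsets: %d bytes (~%.1f MiB)\n",
		structArrayOffsetsSize, float64(structArrayOffsetsSize)/(1<<20))
	// prints: estimated struct array offsets: 44000000 bytes (~42.0 MiB)
}
```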

View File

@ -59,7 +59,21 @@ func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, elementType sch
}
} else if typeutil.IsArrayOfVectorType(dataType) {
if !CheckStrByValues(params, Metric, EmbListMetrics) {
return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], EmbListMetrics)
if typeutil.IsDenseFloatVectorType(elementType) {
if !CheckStrByValues(params, Metric, FloatVectorMetrics) {
return fmt.Errorf("metric type %s not found or not supported for array of vector with float element type, supported: %v", params[Metric], FloatVectorMetrics)
}
} else if typeutil.IsBinaryVectorType(elementType) {
if !CheckStrByValues(params, Metric, BinaryVectorMetrics) {
return fmt.Errorf("metric type %s not found or not supported for array of vector with binary element type, supported: %v", params[Metric], BinaryVectorMetrics)
}
} else if typeutil.IsIntVectorType(elementType) {
if !CheckStrByValues(params, Metric, IntVectorMetrics) {
return fmt.Errorf("metric type %s not found or not supported for array of vector with int element type, supported: %v", params[Metric], IntVectorMetrics)
}
} else {
return fmt.Errorf("metric type %s not found or not supported for array of vector, supported: %v", params[Metric], EmbListMetrics)
}
}
}

View File

@ -98,6 +98,7 @@ message ColumnInfo {
schema.DataType element_type = 7;
bool is_clustering_key = 8;
bool nullable = 9;
bool is_element_level = 10;
}
message ColumnExpr {
@ -245,6 +246,12 @@ message RandomSampleExpr {
Expr predicate = 2;
}
message ElementFilterExpr {
Expr element_expr = 1;
string struct_name = 2;
Expr predicate = 3;
}
message AlwaysTrueExpr {}
message Interval {
@ -285,6 +292,7 @@ message Expr {
RandomSampleExpr random_sample_expr = 16;
GISFunctionFilterExpr gisfunction_filter_expr = 17;
TimestamptzArithCompareExpr timestamptz_arith_compare_expr = 18;
ElementFilterExpr element_filter_expr = 19;
};
bool is_template = 20;
}

File diff suppressed because it is too large