feat: support match operator family (#46518)

issue: https://github.com/milvus-io/milvus/issues/46517
ref: https://github.com/milvus-io/milvus/issues/42148

This PR adds support for the match operator family. For now it works only
on struct-array fields and is evaluated by brute force.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
- Core invariant: match operators only target struct-array element-level
predicates and assume callers provide a correct row_start so element
indices form a contiguous range; IArrayOffsets implementations convert
row-level bitmaps/rows (starting at row_start) into element-level
bitmaps or a contiguous element-offset vector used by brute-force
evaluation.

- New capability added: end-to-end support for MATCH_* semantics
(match_any, match_all, match_least, match_most, match_exact) — parser
(grammar + proto), planner (ParseMatchExprs), expr model
(expr::MatchExpr), compilation (Expr→PhyMatchFilterExpr), execution
(PhyMatchFilterExpr::Eval uses element offsets/bitmaps), and unit tests
(MatchExprTest + parser tests). Implementation currently works for
struct-array inputs and uses brute-force element counting via
RowBitsetToElementOffsets/RowBitsetToElementBitset.

- Logic removed or simplified and why: removed the ad-hoc
DocBitsetToElementOffsets helper and consolidated offset/bitset
derivation into IArrayOffsets::RowBitsetToElementOffsets and a
row_start-aware RowBitsetToElementBitset, and removed EvalCtx overloads
that embedded ExprSet (now EvalCtx(exec_ctx, offset_input)). This
centralizes array-layout logic in ArrayOffsets and removes duplicated
offset conversion and EvalCtx variants that were redundant for
element-level evaluation.

- No data loss / no behavior regression: persistent formats are
unchanged (no proto storage or on-disk layout changed); callers were
updated to supply row_start and now route through the centralized
ArrayOffsets APIs which still use the authoritative
row_to_element_start_ mapping, preserving exact element index mappings.
Eval logic changes are limited to in-memory plumbing (how
offsets/bitmaps are produced and how EvalCtx is constructed); expression
evaluation still invokes exprs_->Eval where needed, so existing behavior
and stored data remain intact.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
Signed-off-by: SpadeA-Tang <tangchenjie1210@gmail.com>
This commit is contained in:
Spade A 2025-12-29 11:03:26 +08:00 committed by GitHub
parent 0d70d2b98c
commit 0114bd1dc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 4102 additions and 1898 deletions

View File

@ -46,25 +46,73 @@ ArrayOffsetsSealed::ElementIDRangeOfRow(int32_t row_id) const {
std::pair<TargetBitmap, TargetBitmap>
ArrayOffsetsSealed::RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const {
int64_t row_count = GetRowCount();
int64_t element_count = GetTotalElementCount();
const TargetBitmapView& valid_row_bitset,
int64_t row_start) const {
int64_t row_count = row_bitset.size();
AssertInfo(row_start >= 0 && row_start + row_count <= GetRowCount(),
"row range out of bounds: row_start={}, row_count={}, "
"total_rows={}",
row_start,
row_count,
GetRowCount());
int64_t element_start = row_to_element_start_[row_start];
int64_t element_end = row_to_element_start_[row_start + row_count];
int64_t element_count = element_end - element_start;
TargetBitmap element_bitset(element_count);
TargetBitmap valid_element_bitset(element_count);
for (int64_t row_id = 0; row_id < row_count; ++row_id) {
int64_t start = row_to_element_start_[row_id];
int64_t end = row_to_element_start_[row_id + 1];
for (int64_t i = 0; i < row_count; ++i) {
int64_t row_id = row_start + i;
int64_t start = row_to_element_start_[row_id] - element_start;
int64_t end = row_to_element_start_[row_id + 1] - element_start;
if (start < end) {
element_bitset.set(start, end - start, row_bitset[row_id]);
valid_element_bitset.set(
start, end - start, valid_row_bitset[row_id]);
element_bitset.set(start, end - start, row_bitset[i]);
valid_element_bitset.set(start, end - start, valid_row_bitset[i]);
}
}
return {std::move(element_bitset), std::move(valid_element_bitset)};
}
// Expands the rows selected in `row_bitset` into the IDs of the elements they
// own.
//
// Bit i of `row_bitset` refers to row `row_start + i`; the window
// [row_start, row_start + row_bitset.size()) must lie within GetRowCount(),
// otherwise the assertion below fires.
//
// Returns the element IDs read from row_to_element_start_ (i.e. not rebased
// to the window), in ascending row order. The result is empty when no bit is
// set.
FixedVector<int32_t>
ArrayOffsetsSealed::RowBitsetToElementOffsets(
    const TargetBitmapView& row_bitset, int64_t row_start) const {
    int64_t row_count = row_bitset.size();
    AssertInfo(row_start >= 0 && row_start + row_count <= GetRowCount(),
               "row range out of bounds: row_start={}, row_count={}, "
               "total_rows={}",
               row_start,
               row_count,
               GetRowCount());

    FixedVector<int32_t> element_offsets;
    const int64_t selected_rows = row_bitset.count();
    if (selected_rows == 0) {
        return element_offsets;
    }

    // Capacity hint only: selected rows times the (integer-truncated) mean
    // number of elements per row.
    const auto total_elements = static_cast<int64_t>(element_row_ids_.size());
    const auto total_rows =
        static_cast<int64_t>(row_to_element_start_.size()) - 1;
    element_offsets.reserve(selected_rows * (total_elements / total_rows));

    for (int64_t i = 0; i < row_count; ++i) {
        if (!row_bitset[i]) {
            continue;
        }
        const int64_t row_id = row_start + i;
        const int32_t elem_end = row_to_element_start_[row_id + 1];
        for (int32_t elem_id = row_to_element_start_[row_id];
             elem_id < elem_end;
             ++elem_id) {
            element_offsets.push_back(elem_id);
        }
    }
    return element_offsets;
}
std::shared_ptr<ArrayOffsetsSealed>
ArrayOffsetsSealed::BuildFromSegment(const void* segment,
const FieldMeta& field_meta) {
@ -193,23 +241,73 @@ ArrayOffsetsGrowing::ElementIDRangeOfRow(int32_t row_id) const {
std::pair<TargetBitmap, TargetBitmap>
ArrayOffsetsGrowing::RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const {
const TargetBitmapView& valid_row_bitset,
int64_t row_start) const {
std::shared_lock lock(mutex_);
int64_t element_count = element_row_ids_.size();
int64_t row_count = row_bitset.size();
AssertInfo(row_start >= 0 && row_start + row_count <= committed_row_count_,
"row range out of bounds: row_start={}, row_count={}, "
"committed_rows={}",
row_start,
row_count,
committed_row_count_);
int64_t element_start = row_to_element_start_[row_start];
int64_t element_end = row_to_element_start_[row_start + row_count];
int64_t element_count = element_end - element_start;
TargetBitmap element_bitset(element_count);
TargetBitmap valid_element_bitset(element_count);
// Direct access to element_row_ids_, no virtual function calls
for (size_t elem_id = 0; elem_id < element_row_ids_.size(); ++elem_id) {
for (int64_t elem_id = element_start; elem_id < element_end; ++elem_id) {
auto row_id = element_row_ids_[elem_id];
element_bitset[elem_id] = row_bitset[row_id];
valid_element_bitset[elem_id] = valid_row_bitset[row_id];
int64_t bitset_idx = row_id - row_start;
element_bitset[elem_id - element_start] = row_bitset[bitset_idx];
valid_element_bitset[elem_id - element_start] =
valid_row_bitset[bitset_idx];
}
return {std::move(element_bitset), std::move(valid_element_bitset)};
}
// Expands the rows selected in `row_bitset` into the IDs of the elements they
// own, while holding a shared lock on mutex_ for the whole call.
//
// Bit i of `row_bitset` refers to row `row_start + i`; the window
// [row_start, row_start + row_bitset.size()) must not extend past
// committed_row_count_, otherwise the assertion below fires.
//
// Returns the element IDs read from row_to_element_start_ (i.e. not rebased
// to the window), in ascending row order. The result is empty when no bit is
// set.
FixedVector<int32_t>
ArrayOffsetsGrowing::RowBitsetToElementOffsets(
    const TargetBitmapView& row_bitset, int64_t row_start) const {
    std::shared_lock lock(mutex_);

    int64_t row_count = row_bitset.size();
    AssertInfo(row_start >= 0 && row_start + row_count <= committed_row_count_,
               "row range out of bounds: row_start={}, row_count={}, "
               "committed_rows={}",
               row_start,
               row_count,
               committed_row_count_);

    FixedVector<int32_t> element_offsets;
    const int64_t selected_rows = row_bitset.count();
    if (selected_rows == 0) {
        return element_offsets;
    }

    // Capacity hint only: selected rows times the (integer-truncated) mean
    // number of elements per row.
    const auto total_elements = static_cast<int64_t>(element_row_ids_.size());
    const auto total_rows =
        static_cast<int64_t>(row_to_element_start_.size()) - 1;
    element_offsets.reserve(selected_rows * (total_elements / total_rows));

    for (int64_t i = 0; i < row_count; ++i) {
        if (!row_bitset[i]) {
            continue;
        }
        const int64_t row_id = row_start + i;
        const int32_t elem_end = row_to_element_start_[row_id + 1];
        for (int32_t elem_id = row_to_element_start_[row_id];
             elem_id < elem_end;
             ++elem_id) {
            element_offsets.push_back(elem_id);
        }
    }
    return element_offsets;
}
void
ArrayOffsetsGrowing::Insert(int64_t row_id_start,
const int32_t* array_lengths,

View File

@ -50,10 +50,18 @@ class IArrayOffsets {
ElementIDRangeOfRow(int32_t row_id) const = 0;
// Convert row-level bitsets to element-level bitsets
// row_start: starting row index (0-based)
// row_bitset.size(): number of rows to process
virtual std::pair<TargetBitmap, TargetBitmap>
RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const = 0;
RowBitsetToElementBitset(const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset,
int64_t row_start) const = 0;
// Convert row-level bitset to element offsets
// Returns element IDs for all rows where row_bitset[row_id] is true
virtual FixedVector<int32_t>
RowBitsetToElementOffsets(const TargetBitmapView& row_bitset,
int64_t row_start) const = 0;
};
class ArrayOffsetsSealed : public IArrayOffsets {
@ -93,9 +101,13 @@ class ArrayOffsetsSealed : public IArrayOffsets {
ElementIDRangeOfRow(int32_t row_id) const override;
std::pair<TargetBitmap, TargetBitmap>
RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const override;
RowBitsetToElementBitset(const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset,
int64_t row_start) const override;
FixedVector<int32_t>
RowBitsetToElementOffsets(const TargetBitmapView& row_bitset,
int64_t row_start) const override;
static std::shared_ptr<ArrayOffsetsSealed>
BuildFromSegment(const void* segment, const FieldMeta& field_meta);
@ -132,9 +144,13 @@ class ArrayOffsetsGrowing : public IArrayOffsets {
ElementIDRangeOfRow(int32_t row_id) const override;
std::pair<TargetBitmap, TargetBitmap>
RowBitsetToElementBitset(
const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset) const override;
RowBitsetToElementBitset(const TargetBitmapView& row_bitset,
const TargetBitmapView& valid_row_bitset,
int64_t row_start) const override;
FixedVector<int32_t>
RowBitsetToElementOffsets(const TargetBitmapView& row_bitset,
int64_t row_start) const override;
private:
struct PendingRow {

View File

@ -114,7 +114,7 @@ TEST_F(ArrayOffsetsTest, SealedRowBitsetToElementBitset) {
valid_row_bitset.size());
auto [elem_bitset, valid_elem_bitset] =
offsets.RowBitsetToElementBitset(row_view, valid_view);
offsets.RowBitsetToElementBitset(row_view, valid_view, 0);
EXPECT_EQ(elem_bitset.size(), 6);
// Elements of row 0 (elem 0, 1) should be true
@ -310,7 +310,7 @@ TEST_F(ArrayOffsetsTest, GrowingRowBitsetToElementBitset) {
valid_row_bitset.size());
auto [elem_bitset, valid_elem_bitset] =
offsets.RowBitsetToElementBitset(row_view, valid_view);
offsets.RowBitsetToElementBitset(row_view, valid_view, 0);
EXPECT_EQ(elem_bitset.size(), 6);
EXPECT_TRUE(elem_bitset[0]);

View File

@ -79,7 +79,7 @@ ElementFilterIterator::FetchAndFilterBatch() {
}
// Step 2: Batch evaluate element-level expression
exec::EvalCtx eval_ctx(exec_context_, expr_set_, &element_ids_buffer_);
exec::EvalCtx eval_ctx(exec_context_, &element_ids_buffer_);
std::vector<VectorPtr> results;
// Evaluate the expression set (should contain only element_expr)

View File

@ -27,23 +27,12 @@
namespace milvus {
namespace exec {
class ExprSet;
using OffsetVector = FixedVector<int32_t>;
class EvalCtx {
public:
EvalCtx(ExecContext* exec_ctx,
ExprSet* expr_set,
OffsetVector* offset_input)
: exec_ctx_(exec_ctx),
expr_set_(expr_set),
offset_input_(offset_input) {
EvalCtx(ExecContext* exec_ctx, OffsetVector* offset_input)
: exec_ctx_(exec_ctx), offset_input_(offset_input) {
assert(exec_ctx_ != nullptr);
assert(expr_set_ != nullptr);
}
explicit EvalCtx(ExecContext* exec_ctx, ExprSet* expr_set)
: exec_ctx_(exec_ctx), expr_set_(expr_set) {
}
explicit EvalCtx(ExecContext* exec_ctx) : exec_ctx_(exec_ctx) {
@ -96,7 +85,6 @@ class EvalCtx {
private:
ExecContext* exec_ctx_ = nullptr;
ExprSet* expr_set_ = nullptr;
// we may accept offsets array as input and do expr filtering on these data
OffsetVector* offset_input_ = nullptr;
bool input_no_nulls_ = false;

View File

@ -34,7 +34,11 @@ PhyExistsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
context.set_apply_valid_data_after_flip(false);
auto input = context.get_offset_input();
SetHasOffsetInput((input != nullptr));
switch (expr_->column_.data_type_) {
auto data_type = expr_->column_.data_type_;
if (expr_->column_.element_level_) {
data_type = expr_->column_.element_type_;
}
switch (data_type) {
case DataType::JSON: {
if (SegmentExpr::CanUseIndex() && !has_offset_input_) {
result = EvalJsonExistsForIndex();
@ -44,9 +48,7 @@ PhyExistsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
break;
}
default:
ThrowInfo(DataTypeInvalid,
"unsupported data type: {}",
expr_->column_.data_type_);
ThrowInfo(DataTypeInvalid, "unsupported data type: {}", data_type);
}
}

View File

@ -30,6 +30,7 @@
#include "exec/expression/JsonContainsExpr.h"
#include "exec/expression/LogicalBinaryExpr.h"
#include "exec/expression/LogicalUnaryExpr.h"
#include "exec/expression/MatchExpr.h"
#include "exec/expression/NullExpr.h"
#include "exec/expression/TermExpr.h"
#include "exec/expression/UnaryExpr.h"
@ -294,6 +295,17 @@ CompileExpression(const expr::TypedExprPtr& expr,
context->get_active_count(),
context->query_config()->get_expr_batch_size(),
context->get_consistency_level());
} else if (auto match_expr =
std::dynamic_pointer_cast<const milvus::expr::MatchExpr>(
expr)) {
result = std::make_shared<PhyMatchFilterExpr>(
compiled_inputs,
match_expr,
"PhyMatchFilterExpr",
op_ctx,
context->get_segment(),
context->get_active_count(),
context->query_config()->get_expr_batch_size());
} else {
ThrowInfo(ExprInvalid, "unsupport expr: ", expr->ToString());
}

View File

@ -0,0 +1,258 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "MatchExpr.h"
#include <numeric>
#include <utility>
#include "common/Tracer.h"
#include "common/Types.h"
namespace milvus {
namespace exec {
using MatchType = milvus::expr::MatchType;
// Decides, row by row, whether a row satisfies a MATCH_* predicate, given the
// per-element outcome of the element-level predicate.
//
// Template parameters:
//   match_type - MATCH_* semantics to apply, resolved at compile time.
//   all_valid  - when true, `valid_bitset` is never read and every element of
//                the row counts; when false, elements whose valid bit is unset
//                are skipped and do not count towards the row's element total.
//
// Parameters:
//   row_count     - rows [0, row_count) are examined.
//   array_offsets - supplies each row's half-open element ID range through
//                   ElementIDRangeOfRow().
//   match_bitset  - indexed by element ID; true where the predicate matched.
//   valid_bitset  - indexed by element ID; consulted only when !all_valid.
//   result_bitset - indexed by row; the bit of every matching row is set to
//                   true. Bits are never cleared here, so non-matching rows
//                   keep whatever value the caller initialized them with.
//   threshold     - hit-count bound used by MatchLeast / MatchMost /
//                   MatchExact; not read by MatchAny and MatchAll.
//
// A row without any counted element satisfies MatchAll (0 hits == 0 elements).
template <MatchType match_type, bool all_valid>
void
ProcessMatchRows(int64_t row_count,
                 const IArrayOffsets* array_offsets,
                 const TargetBitmapView& match_bitset,
                 const TargetBitmapView& valid_bitset,
                 TargetBitmapView& result_bitset,
                 int64_t threshold) {
    for (int64_t row = 0; row < row_count; ++row) {
        auto [elem_begin, elem_end] = array_offsets->ElementIDRangeOfRow(row);

        // With all elements valid the total is known up front; otherwise it
        // is accumulated while scanning.
        int64_t counted = all_valid ? (elem_end - elem_begin) : 0;
        int64_t hits = 0;
        // Set when the scan already proved the row cannot match.
        bool rejected = false;

        for (auto elem = elem_begin; elem < elem_end; ++elem) {
            if constexpr (!all_valid) {
                if (!valid_bitset[elem]) {
                    continue;
                }
                ++counted;
            }

            const bool matched = match_bitset[elem];
            if (matched) {
                ++hits;
            }

            // Stop scanning as soon as the outcome for this row is decided.
            if constexpr (match_type == MatchType::MatchAny) {
                if (hits > 0) {
                    break;
                }
            } else if constexpr (match_type == MatchType::MatchAll) {
                if (!matched) {
                    rejected = true;
                    break;
                }
            } else if constexpr (match_type == MatchType::MatchLeast) {
                if (hits >= threshold) {
                    break;
                }
            } else if constexpr (match_type == MatchType::MatchMost ||
                                 match_type == MatchType::MatchExact) {
                if (hits > threshold) {
                    rejected = true;
                    break;
                }
            }
        }

        if (rejected) {
            continue;
        }

        bool is_match = false;
        if constexpr (match_type == MatchType::MatchAny) {
            is_match = hits > 0;
        } else if constexpr (match_type == MatchType::MatchAll) {
            is_match = hits == counted;
        } else if constexpr (match_type == MatchType::MatchLeast) {
            is_match = hits >= threshold;
        } else if constexpr (match_type == MatchType::MatchMost) {
            is_match = hits <= threshold;
        } else if constexpr (match_type == MatchType::MatchExact) {
            is_match = hits == threshold;
        }

        if (is_match) {
            result_bitset[row] = true;
        }
    }
}
// Turns the runtime `all_valid` flag into the compile-time `all_valid`
// template argument of ProcessMatchRows and forwards every other argument
// unchanged.
template <MatchType match_type>
void
DispatchByValidity(bool all_valid,
                   int64_t row_count,
                   const IArrayOffsets* array_offsets,
                   const TargetBitmapView& match_bitset,
                   const TargetBitmapView& valid_bitset,
                   TargetBitmapView& result_bitset,
                   int64_t threshold) {
    if (!all_valid) {
        // Some elements are invalid: use the variant that reads valid_bitset.
        ProcessMatchRows<match_type, false>(row_count,
                                            array_offsets,
                                            match_bitset,
                                            valid_bitset,
                                            result_bitset,
                                            threshold);
        return;
    }
    ProcessMatchRows<match_type, true>(row_count,
                                       array_offsets,
                                       match_bitset,
                                       valid_bitset,
                                       result_bitset,
                                       threshold);
}
// Evaluates the MATCH_* predicate for all active rows in a single pass.
//
// Steps:
//   1. Evaluate the element-level predicate (inputs_[0]) over an identity
//      offset vector [0, total_elements), where total_elements is the start
//      of the element range reported for row index `row_count`.
//   2. Fold the per-element bitmap into one bit per row according to the
//      expression's match type and count.
//
// `result` receives a ColumnVector bitmap with one bit per active row; its
// valid bitmap is initialized to all-true.
// An offset input on `context` is not supported and trips an assertion.
// Throws OpTypeInvalid for an unknown match type.
void
PhyMatchFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
    tracer::AutoSpan span("PhyMatchFilterExpr::Eval", tracer::GetRootSpan());

    auto input = context.get_offset_input();
    AssertInfo(input == nullptr,
               "Offset input in match filter expr is not implemented now");

    // Bind by const reference so that no copy of the schema / field meta is
    // made per evaluation when the getters return references (lifetime
    // extension keeps this correct if they return by value).
    // NOTE(review): assumes GetFirstArrayFieldInStruct() and get_id() are
    // const members - confirm.
    const auto& schema = segment_->get_schema();
    const auto& field_meta =
        schema.GetFirstArrayFieldInStruct(expr_->get_struct_name());

    auto array_offsets = segment_->GetArrayOffsets(field_meta.get_id());
    AssertInfo(array_offsets != nullptr, "Array offsets not available");

    int64_t row_count =
        context.get_exec_context()->get_query_context()->get_active_count();
    // Row-level result: all bits false, all rows valid.
    result = std::make_shared<ColumnVector>(TargetBitmap(row_count, false),
                                            TargetBitmap(row_count, true));

    auto col_vec = std::dynamic_pointer_cast<ColumnVector>(result);
    AssertInfo(col_vec != nullptr, "Result should be ColumnVector");
    AssertInfo(col_vec->IsBitmap(), "Result should be bitmap");

    auto col_vec_size = col_vec->size();
    TargetBitmapView bitset_view(col_vec->GetRawData(), col_vec_size);

    // Build the identity offsets 0..total_elements-1 so the element-level
    // predicate is evaluated on every element.
    auto [total_elements, _] = array_offsets->ElementIDRangeOfRow(row_count);
    FixedVector<int32_t> element_offsets(total_elements);
    std::iota(element_offsets.begin(), element_offsets.end(), 0);

    EvalCtx eval_ctx(context.get_exec_context(), &element_offsets);

    VectorPtr match_result;
    // TODO(SpadeA): can be executed in batch
    inputs_[0]->Eval(eval_ctx, match_result);
    auto match_result_col_vec =
        std::dynamic_pointer_cast<ColumnVector>(match_result);
    AssertInfo(match_result_col_vec != nullptr,
               "Match result should be ColumnVector");
    AssertInfo(match_result_col_vec->IsBitmap(),
               "Match result should be bitmap");

    TargetBitmapView match_result_bitset_view(
        match_result_col_vec->GetRawData(), match_result_col_vec->size());
    TargetBitmapView match_result_valid_view(
        match_result_col_vec->GetValidRawData(), match_result_col_vec->size());

    // When every element is valid the per-element validity check is skipped.
    bool all_valid = match_result_valid_view.all();

    auto match_type = expr_->get_match_type();
    int64_t threshold = expr_->get_count();

    switch (match_type) {
        case MatchType::MatchAny:
            DispatchByValidity<MatchType::MatchAny>(all_valid,
                                                    row_count,
                                                    array_offsets.get(),
                                                    match_result_bitset_view,
                                                    match_result_valid_view,
                                                    bitset_view,
                                                    threshold);
            break;
        case MatchType::MatchAll:
            DispatchByValidity<MatchType::MatchAll>(all_valid,
                                                    row_count,
                                                    array_offsets.get(),
                                                    match_result_bitset_view,
                                                    match_result_valid_view,
                                                    bitset_view,
                                                    threshold);
            break;
        case MatchType::MatchLeast:
            DispatchByValidity<MatchType::MatchLeast>(all_valid,
                                                      row_count,
                                                      array_offsets.get(),
                                                      match_result_bitset_view,
                                                      match_result_valid_view,
                                                      bitset_view,
                                                      threshold);
            break;
        case MatchType::MatchMost:
            DispatchByValidity<MatchType::MatchMost>(all_valid,
                                                     row_count,
                                                     array_offsets.get(),
                                                     match_result_bitset_view,
                                                     match_result_valid_view,
                                                     bitset_view,
                                                     threshold);
            break;
        case MatchType::MatchExact:
            DispatchByValidity<MatchType::MatchExact>(all_valid,
                                                      row_count,
                                                      array_offsets.get(),
                                                      match_result_bitset_view,
                                                      match_result_valid_view,
                                                      bitset_view,
                                                      threshold);
            break;
        default:
            ThrowInfo(OpTypeInvalid,
                      "Unsupported match type: {}",
                      static_cast<int>(match_type));
    }
}
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,86 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fmt/core.h>
#include "common/EasyAssert.h"
#include "common/OpContext.h"
#include "common/Types.h"
#include "common/Vector.h"
#include "exec/expression/Expr.h"
#include "segcore/SegmentInterface.h"
namespace milvus {
namespace exec {
// Executable form of a MATCH_* expression (expr::MatchExpr).
//
// Stores the logical expression, the segment it is evaluated against, and the
// batching parameters used by MoveCursor(). Eval() is defined out of line.
class PhyMatchFilterExpr : public Expr {
 public:
    // input        - child physical expressions handed to the Expr base.
    // expr         - logical match expression this node executes.
    // name         - node name handed to the Expr base.
    // op_ctx       - operation context handed to the Expr base.
    // segment      - segment to evaluate against (not owned).
    // active_count - row count used as the upper bound in MoveCursor().
    // batch_size   - step size used by MoveCursor().
    PhyMatchFilterExpr(
        const std::vector<std::shared_ptr<Expr>>& input,
        const std::shared_ptr<const milvus::expr::MatchExpr>& expr,
        const std::string& name,
        milvus::OpContext* op_ctx,
        const segcore::SegmentInternalInterface* segment,
        int64_t active_count,
        int64_t batch_size)
        // NOTE(review): `input` is a const reference, so std::move yields a
        // const rvalue here; kept as written because the base constructor's
        // parameter type is not visible from this file.
        : Expr(DataType::BOOL, std::move(input), name, op_ctx),
          expr_(expr),
          segment_(segment),
          active_count_(active_count),
          batch_size_(batch_size) {
    }

    void
    Eval(EvalCtx& context, VectorPtr& result) override;

    // Advances current_pos_ by one batch, clamped so it never passes
    // active_count_. Does nothing when the expression has offset input.
    void
    MoveCursor() override {
        if (has_offset_input_) {
            return;
        }
        const int64_t remaining = active_count_ - current_pos_;
        current_pos_ += (batch_size_ >= remaining) ? remaining : batch_size_;
    }

    // Returns the textual form of the wrapped logical expression.
    std::string
    ToString() const override {
        return expr_->ToString();
    }

    bool
    IsSource() const override {
        return true;
    }

    // This node reports no column info.
    std::optional<milvus::expr::ColumnInfo>
    GetColumnInfo() const override {
        return std::nullopt;
    }

 private:
    std::shared_ptr<const milvus::expr::MatchExpr> expr_;
    const segcore::SegmentInternalInterface* segment_;
    int64_t active_count_;
    int64_t current_pos_{0};
    int64_t batch_size_;
};
} // namespace exec
} // namespace milvus

View File

@ -0,0 +1,420 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License
#include <gtest/gtest.h>
#include <boost/format.hpp>
#include <iostream>
#include <random>
#include "common/Schema.h"
#include "pb/plan.pb.h"
#include "query/Plan.h"
#include "segcore/SegmentGrowingImpl.h"
#include "test_utils/DataGen.h"
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
// Fixture for the MATCH_* operator tests.
//
// SetUp() builds a growing segment of N_ rows with a float vector, an int64
// primary key and two struct-array sub-fields ("sub_str", "sub_int"), each
// holding array_len_ elements per row. The element predicate used by every
// test is: sub_str == "aaa" && sub_int > 100.
class MatchExprTest : public ::testing::Test {
 protected:
    void
    SetUp() override {
        // Create schema with struct array sub-fields
        schema_ = std::make_shared<Schema>();
        vec_fid_ = schema_->AddDebugField(
            "vec", DataType::VECTOR_FLOAT, 4, knowhere::metric::L2);
        int64_fid_ = schema_->AddDebugField("id", DataType::INT64);
        schema_->set_primary_field_id(int64_fid_);

        sub_str_fid_ = schema_->AddDebugArrayField(
            "struct_array[sub_str]", DataType::VARCHAR, false);
        sub_int_fid_ = schema_->AddDebugArrayField(
            "struct_array[sub_int]", DataType::INT32, false);

        // Generate test data
        GenerateTestData();

        // Create and populate segment
        seg_ = CreateGrowingSegment(schema_, empty_index_meta);
        seg_->PreInsert(N_);
        seg_->Insert(
            0, N_, row_ids_.data(), timestamps_.data(), insert_data_.get());
    }

    // Fills insert_data_, sub_str_data_, sub_int_data_, row_ids_ and
    // timestamps_. The RNG is seeded with a constant, so the generated data is
    // the same on every run; the order of RNG draws below must not change.
    void
    GenerateTestData() {
        std::default_random_engine rng(42);
        std::vector<std::string> str_choices = {"aaa", "bbb", "ccc"};
        std::uniform_int_distribution<> str_dist(0, 2);
        std::uniform_int_distribution<> int_dist(50, 150);

        insert_data_ = std::make_unique<InsertRecordProto>();

        // Generate vector field
        std::vector<float> vec_data(N_ * 4);
        std::normal_distribution<float> vec_dist(0, 1);
        for (auto& v : vec_data) {
            v = vec_dist(rng);
        }
        auto vec_array = CreateDataArrayFrom(
            vec_data.data(), nullptr, N_, schema_->operator[](vec_fid_));
        insert_data_->mutable_fields_data()->AddAllocated(vec_array.release());

        // Generate id field
        std::vector<int64_t> id_data(N_);
        for (size_t i = 0; i < N_; ++i) {
            id_data[i] = i;
        }
        auto id_array = CreateDataArrayFrom(
            id_data.data(), nullptr, N_, schema_->operator[](int64_fid_));
        insert_data_->mutable_fields_data()->AddAllocated(id_array.release());

        // Generate struct_array[sub_str]
        sub_str_data_.resize(N_);
        for (size_t i = 0; i < N_; ++i) {
            for (int j = 0; j < array_len_; ++j) {
                sub_str_data_[i].mutable_string_data()->add_data(
                    str_choices[str_dist(rng)]);
            }
        }
        auto sub_str_array =
            CreateDataArrayFrom(sub_str_data_.data(),
                                nullptr,
                                N_,
                                schema_->operator[](sub_str_fid_));
        insert_data_->mutable_fields_data()->AddAllocated(
            sub_str_array.release());

        // Generate struct_array[sub_int]
        sub_int_data_.resize(N_);
        for (size_t i = 0; i < N_; ++i) {
            for (int j = 0; j < array_len_; ++j) {
                sub_int_data_[i].mutable_int_data()->add_data(int_dist(rng));
            }
        }
        auto sub_int_array =
            CreateDataArrayFrom(sub_int_data_.data(),
                                nullptr,
                                N_,
                                schema_->operator[](sub_int_fid_));
        insert_data_->mutable_fields_data()->AddAllocated(
            sub_int_array.release());

        insert_data_->set_num_rows(N_);

        // Generate row_ids and timestamps
        row_ids_.resize(N_);
        timestamps_.resize(N_);
        for (size_t i = 0; i < N_; ++i) {
            row_ids_[i] = i;
            timestamps_[i] = i;
        }
    }

    // Count elements matching: sub_str == "aaa" && sub_int > 100
    int
    CountMatchingElements(int64_t row_idx) const {
        int count = 0;
        const auto& str_field = sub_str_data_[row_idx];
        const auto& int_field = sub_int_data_[row_idx];
        for (int j = 0; j < array_len_; ++j) {
            bool str_match = (str_field.string_data().data(j) == "aaa");
            bool int_match = (int_field.int_data().data(j) > 100);
            if (str_match && int_match) {
                ++count;
            }
        }
        return count;
    }

    // Create plan with specified match type and count.
    // Placeholders: %1% vector field id, %2% sub_str field id, %3% sub_int
    // field id, %4% match type, %5% count.
    std::string
    CreatePlanText(const std::string& match_type, int64_t count) {
        return boost::str(boost::format(R"(vector_anns: <
field_id: %1%
predicates: <
match_expr: <
struct_name: "struct_array"
match_type: %4%
count: %5%
predicate: <
binary_expr: <
op: LogicalAnd
left: <
unary_range_expr: <
column_info: <
field_id: %2%
data_type: Array
element_type: VarChar
nested_path: "sub_str"
is_element_level: true
>
op: Equal
value: <
string_val: "aaa"
>
>
>
right: <
unary_range_expr: <
column_info: <
field_id: %3%
data_type: Array
element_type: Int32
nested_path: "sub_int"
is_element_level: true
>
op: GreaterThan
value: <
int64_val: 100
>
>
>
>
>
>
>
query_info: <
topk: 10
round_decimal: 3
metric_type: "L2"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)") % vec_fid_.get() %
                          sub_str_fid_.get() % sub_int_fid_.get() % match_type %
                          count);
    }

    // Execute search and return results
    std::unique_ptr<SearchResult>
    ExecuteSearch(const std::string& raw_plan) {
        proto::plan::PlanNode plan_node;
        auto ok =
            google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
        EXPECT_TRUE(ok) << "Failed to parse plan";

        auto plan = CreateSearchPlanFromPlanNode(schema_, plan_node);
        EXPECT_NE(plan, nullptr);

        auto num_queries = 1;
        auto ph_group_raw = CreatePlaceholderGroup(num_queries, 4, 1024);
        auto ph_group =
            ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());

        // Unsigned literal: the same bit pattern as `1L << 63`, but without
        // depending on `long` being 64 bits wide or on shifting into the sign
        // bit of a signed type.
        return seg_->Search(plan.get(), ph_group.get(), 1ULL << 63);
    }

    // Verify results based on match type
    using VerifyFunc = std::function<bool(
        int match_count, int element_count, int64_t threshold)>;

    // Prints every returned hit and, for each hit whose offset is a valid row
    // index, recomputes the number of matching elements from the generated
    // data and expects `verify_func` to accept it. Only returned rows are
    // checked; rows that were filtered out are not inspected.
    void
    VerifyResults(const SearchResult* result,
                  const std::string& match_type_name,
                  int64_t threshold,
                  VerifyFunc verify_func) {
        std::cout << "=== " << match_type_name << " Results ===" << std::endl;
        std::cout << "total_nq: " << result->total_nq_ << std::endl;
        std::cout << "unity_topK: " << result->unity_topK_ << std::endl;
        std::cout << "num_results: " << result->seg_offsets_.size()
                  << std::endl;

        for (int64_t i = 0; i < result->total_nq_; ++i) {
            std::cout << "Query " << i << ":" << std::endl;
            for (int64_t k = 0; k < result->unity_topK_; ++k) {
                int64_t idx = i * result->unity_topK_ + k;
                auto offset = result->seg_offsets_[idx];
                auto distance = result->distances_[idx];

                std::cout << "  [" << k << "] offset=" << offset
                          << ", distance=" << distance;

                if (offset >= 0 && offset < static_cast<int64_t>(N_)) {
                    // Print sub_str array
                    std::cout << ", sub_str=[";
                    const auto& str_field = sub_str_data_[offset];
                    for (int j = 0; j < str_field.string_data().data_size();
                         ++j) {
                        if (j > 0)
                            std::cout << ",";
                        std::cout << str_field.string_data().data(j);
                    }
                    std::cout << "]";

                    // Print sub_int array
                    std::cout << ", sub_int=[";
                    const auto& int_field = sub_int_data_[offset];
                    for (int j = 0; j < int_field.int_data().data_size(); ++j) {
                        if (j > 0)
                            std::cout << ",";
                        std::cout << int_field.int_data().data(j);
                    }
                    std::cout << "]";

                    // Print match_count and verify
                    int match_count = CountMatchingElements(offset);
                    bool expected =
                        verify_func(match_count, array_len_, threshold);
                    std::cout << ", match_count=" << match_count;

                    EXPECT_TRUE(expected)
                        << match_type_name << " failed for row " << offset
                        << ": match_count=" << match_count
                        << ", element_count=" << array_len_
                        << ", threshold=" << threshold;
                }
                std::cout << std::endl;
            }
        }
        std::cout << "==============================" << std::endl;
    }

    // Member variables
    std::shared_ptr<Schema> schema_;
    FieldId vec_fid_;
    FieldId int64_fid_;
    FieldId sub_str_fid_;
    FieldId sub_int_fid_;

    std::unique_ptr<InsertRecordProto> insert_data_;
    std::vector<milvus::proto::schema::ScalarField> sub_str_data_;
    std::vector<milvus::proto::schema::ScalarField> sub_int_data_;
    std::vector<idx_t> row_ids_;
    std::vector<Timestamp> timestamps_;

    SegmentGrowingPtr seg_;

    static constexpr size_t N_ = 1000;      // number of rows inserted
    static constexpr int array_len_ = 5;    // elements per struct-array row
};
TEST_F(MatchExprTest, MatchAny) {
    // MatchAny: every returned row must contain at least one matching element.
    const auto has_any_hit =
        [](int match_count, int /*element_count*/, int64_t /*threshold*/) {
            return match_count > 0;
        };

    const auto search_result = ExecuteSearch(CreatePlanText("MatchAny", 0));
    VerifyResults(search_result.get(), "MatchAny", 0, has_any_hit);
}
TEST_F(MatchExprTest, MatchAll) {
    // MatchAll: every element of a returned row must match.
    const auto all_hit =
        [](int match_count, int element_count, int64_t /*threshold*/) {
            return match_count == element_count;
        };

    const auto search_result = ExecuteSearch(CreatePlanText("MatchAll", 0));
    VerifyResults(search_result.get(), "MatchAll", 0, all_hit);
}
TEST_F(MatchExprTest, MatchLeast) {
    // MatchLeast: a returned row needs at least kMinHits matching elements.
    constexpr int64_t kMinHits = 3;
    const auto at_least =
        [](int match_count, int /*element_count*/, int64_t min_hits) {
            return match_count >= min_hits;
        };

    const auto search_result =
        ExecuteSearch(CreatePlanText("MatchLeast", kMinHits));
    VerifyResults(search_result.get(), "MatchLeast(3)", kMinHits, at_least);
}
TEST_F(MatchExprTest, MatchMost) {
    // MatchMost: a returned row may have at most kMaxHits matching elements.
    constexpr int64_t kMaxHits = 2;
    const auto at_most =
        [](int match_count, int /*element_count*/, int64_t max_hits) {
            return match_count <= max_hits;
        };

    const auto search_result =
        ExecuteSearch(CreatePlanText("MatchMost", kMaxHits));
    VerifyResults(search_result.get(), "MatchMost(2)", kMaxHits, at_most);
}
// MATCH_EXACT with threshold 2: a row qualifies when exactly two of its
// elements satisfy the predicate.
TEST_F(MatchExprTest, MatchExact) {
    const int64_t exact_matches = 2;
    const auto hits_exact_count = [](int match_count,
                                     int /*element_count*/,
                                     int64_t threshold) {
        return threshold == match_count;
    };

    auto plan_text = CreatePlanText("MatchExact", exact_matches);
    auto search_result = ExecuteSearch(plan_text);
    VerifyResults(
        search_result.get(), "MatchExact(2)", exact_matches, hits_exact_count);
}
// Boundary case: MATCH_LEAST with threshold 1, which accepts the same rows as
// MATCH_ANY (one or more matching elements).
TEST_F(MatchExprTest, MatchLeastOne) {
    const int64_t min_matches = 1;
    const auto reaches_minimum = [](int match_count,
                                    int /*element_count*/,
                                    int64_t threshold) {
        return threshold <= match_count;
    };

    auto plan_text = CreatePlanText("MatchLeast", min_matches);
    auto search_result = ExecuteSearch(plan_text);
    VerifyResults(
        search_result.get(), "MatchLeast(1)", min_matches, reaches_minimum);
}
// Boundary case: MATCH_MOST with threshold 0, i.e. a row qualifies only when
// none of its elements satisfy the predicate.
TEST_F(MatchExprTest, MatchMostZero) {
    const int64_t max_matches = 0;
    const auto stays_within_maximum = [](int match_count,
                                         int /*element_count*/,
                                         int64_t threshold) {
        return threshold >= match_count;
    };

    auto plan_text = CreatePlanText("MatchMost", max_matches);
    auto search_result = ExecuteSearch(plan_text);
    VerifyResults(
        search_result.get(), "MatchMost(0)", max_matches, stays_within_maximum);
}
// Boundary case: MATCH_EXACT with threshold 0, i.e. a row qualifies only when
// the count of matching elements is zero.
TEST_F(MatchExprTest, MatchExactZero) {
    const int64_t exact_matches = 0;
    const auto hits_exact_count = [](int match_count,
                                     int /*element_count*/,
                                     int64_t threshold) {
        return threshold == match_count;
    };

    auto plan_text = CreatePlanText("MatchExact", exact_matches);
    auto search_result = ExecuteSearch(plan_text);
    VerifyResults(
        search_result.get(), "MatchExact(0)", exact_matches, hits_exact_count);
}

View File

@ -31,7 +31,11 @@ PhyNullExpr::Eval(EvalCtx& context, VectorPtr& result) {
static_cast<int>(expr_->column_.data_type_));
auto input = context.get_offset_input();
switch (expr_->column_.data_type_) {
auto data_type = expr_->column_.data_type_;
if (expr_->column_.element_level_) {
data_type = expr_->column_.element_type_;
}
switch (data_type) {
case DataType::BOOL: {
result = ExecVisitorImpl<bool>(input);
break;

View File

@ -88,7 +88,7 @@ PhyElementFilterBitsNode::GetOutput() {
// Step 3: Convert doc bitset to element offsets
FixedVector<int32_t> element_offsets =
DocBitsetToElementOffsets(doc_bitset);
array_offsets->RowBitsetToElementOffsets(doc_bitset, 0);
// Step 4: Evaluate element expression
auto [expr_result, valid_expr_result] =
@ -122,38 +122,6 @@ PhyElementFilterBitsNode::GetOutput() {
return std::make_shared<RowVector>(col_res);
}
// Expands a row-level ("doc") bitset into the list of element IDs owned by
// every row whose bit is set.
//
// The bitset must have exactly one bit per row known to the query context's
// array offsets; both that and the presence of the array offsets are asserted.
// Element IDs are appended in ascending row order and narrowed to int32_t.
FixedVector<int32_t>
PhyElementFilterBitsNode::DocBitsetToElementOffsets(
    const TargetBitmapView& doc_bitset) {
    auto array_offsets = query_context_->get_array_offsets();
    AssertInfo(array_offsets != nullptr, "Array offsets not available");

    int64_t doc_count = array_offsets->GetRowCount();
    AssertInfo(doc_bitset.size() == doc_count,
               "Doc bitset size mismatch: {} vs {}",
               doc_bitset.size(),
               doc_count);

    FixedVector<int32_t> element_offsets;
    // Reserve for the case where every row passes the filter.
    element_offsets.reserve(array_offsets->GetTotalElementCount());

    for (int64_t row = 0; row < doc_count; ++row) {
        if (!doc_bitset[row]) {
            continue;  // row filtered out: contributes no elements
        }
        // Half-open element ID range [range_begin, range_end) of this row.
        auto [range_begin, range_end] =
            array_offsets->ElementIDRangeOfRow(row);
        int64_t elem = range_begin;
        while (elem < range_end) {
            element_offsets.push_back(static_cast<int32_t>(elem));
            ++elem;
        }
    }
    return element_offsets;
}
std::pair<TargetBitmap, TargetBitmap>
PhyElementFilterBitsNode::EvaluateElementExpression(
FixedVector<int32_t>& element_offsets) {
@ -162,10 +130,8 @@ PhyElementFilterBitsNode::EvaluateElementExpression(
true);
tracer::AddEvent(fmt::format("input_elements: {}", element_offsets.size()));
// Use offset interface by passing element_offsets as third parameter
EvalCtx eval_ctx(operator_context_->get_exec_context(),
element_exprs_.get(),
&element_offsets);
// Use offset interface by passing element_offsets
EvalCtx eval_ctx(operator_context_->get_exec_context(), &element_offsets);
std::vector<VectorPtr> results;
element_exprs_->Eval(0, 1, true, eval_ctx, results);

View File

@ -77,9 +77,6 @@ class PhyElementFilterBitsNode : public Operator {
}
private:
FixedVector<int32_t>
DocBitsetToElementOffsets(const TargetBitmapView& doc_bitset);
std::pair<TargetBitmap, TargetBitmap>
EvaluateElementExpression(FixedVector<int32_t>& element_offsets);

View File

@ -74,7 +74,7 @@ PhyFilterBitsNode::GetOutput() {
std::chrono::high_resolution_clock::time_point scalar_start =
std::chrono::high_resolution_clock::now();
EvalCtx eval_ctx(operator_context_->get_exec_context(), exprs_.get());
EvalCtx eval_ctx(operator_context_->get_exec_context());
TargetBitmap bitset;
TargetBitmap valid_bitset;

View File

@ -160,7 +160,7 @@ PhyIterativeFilterNode::GetOutput() {
TargetBitmap bitset;
// get bitset of whole segment first
if (!is_native_supported_) {
EvalCtx eval_ctx(operator_context_->get_exec_context(), exprs_.get());
EvalCtx eval_ctx(operator_context_->get_exec_context());
TargetBitmap valid_bitset;
while (num_processed_rows_ < need_process_rows_) {
@ -225,8 +225,7 @@ PhyIterativeFilterNode::GetOutput() {
std::unordered_set<int64_t> unique_doc_ids;
for (auto& iterator : search_result.vector_iterators_.value()) {
EvalCtx eval_ctx(operator_context_->get_exec_context(),
exprs_.get());
EvalCtx eval_ctx(operator_context_->get_exec_context());
int topk = 0;
while (iterator->HasNext() && topk < unity_topk) {
FixedVector<int32_t> offsets;

View File

@ -111,7 +111,7 @@ PhyRescoresNode::GetOutput() {
filters.emplace_back(filter);
auto expr_set = std::make_unique<ExprSet>(filters, exec_context);
std::vector<VectorPtr> results;
EvalCtx eval_ctx(exec_context, expr_set.get());
EvalCtx eval_ctx(exec_context);
const auto& exprs = expr_set->exprs();
bool is_native_supported = true;

View File

@ -104,7 +104,7 @@ PhyVectorSearchNode::GetOutput() {
col_input->size());
auto [element_bitset, valid_element_bitset] =
array_offsets->RowBitsetToElementBitset(view, valid_view);
array_offsets->RowBitsetToElementBitset(view, valid_view, 0);
query_context_->set_active_element_count(element_bitset.size());

View File

@ -890,6 +890,49 @@ class JsonContainsExpr : public ITypeFilterExpr {
bool same_type_;
const std::vector<proto::plan::GenericValue> vals_;
};
// MatchType mirrors the protobuf MatchType enum
using MatchType = proto::plan::MatchType;

// Logical expression node for the MATCH_* operator family.
//
// Carries the struct name the match applies to, the match type, an optional
// element count, and the element-level predicate, which is stored as the
// single entry of inputs_.
class MatchExpr : public ITypeFilterExpr {
 public:
    // struct_name: name of the struct the predicate is evaluated against.
    // match_type:  which MATCH_* variant this node represents.
    // count:       threshold used by MatchLeast/MatchMost/MatchExact.
    // predicate:   element-level predicate; becomes inputs_[0].
    MatchExpr(const std::string& struct_name,
              MatchType match_type,
              int64_t count,
              const TypedExprPtr& predicate)
        : struct_name_(struct_name), match_type_(match_type), count_(count) {
        inputs_.push_back(predicate);
    }

    // Renders the node including its predicate, so that two match expressions
    // that differ only in the predicate do not produce the same string.
    std::string
    ToString() const override {
        // Guard against a missing predicate so ToString never dereferences
        // a null input.
        const bool has_predicate = !inputs_.empty() && inputs_[0] != nullptr;
        const std::string predicate_str =
            has_predicate ? inputs_[0]->ToString() : std::string("null");
        return fmt::format(
            "MatchExpr(struct_name={}, match_type={}, count={}, predicate={})",
            struct_name_,
            proto::plan::MatchType_Name(match_type_),
            count_,
            predicate_str);
    }

    const std::string&
    get_struct_name() const {
        return struct_name_;
    }

    MatchType
    get_match_type() const {
        return match_type_;
    }

    int64_t
    get_count() const {
        return count_;
    }

 private:
    std::string struct_name_;
    MatchType match_type_;
    int64_t count_;  // Used for MatchLeast/MatchMost/MatchExact
};
} // namespace expr
} // namespace milvus

View File

@ -445,7 +445,7 @@ ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
if (column_info.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() ==
static_cast<DataType>(column_info.data_type()));
static_cast<DataType>(column_info.element_type()));
} else {
Assert(data_type == static_cast<DataType>(column_info.data_type()));
}
@ -470,7 +470,7 @@ ProtoParser::ParseNullExprs(const proto::plan::NullExpr& expr_pb) {
if (column_info.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() ==
static_cast<DataType>(column_info.data_type()));
static_cast<DataType>(column_info.element_type()));
} else {
Assert(data_type == static_cast<DataType>(column_info.data_type()));
}
@ -488,7 +488,7 @@ ProtoParser::ParseBinaryRangeExprs(
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
@ -510,7 +510,7 @@ ProtoParser::ParseTimestamptzArithCompareExprs(
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
@ -533,6 +533,16 @@ ProtoParser::ParseElementFilterExprs(
"ElementFilterExpr must be handled at PlanNode level.");
}
// Converts a protobuf MatchExpr into its expr::MatchExpr counterpart.
// The nested predicate is parsed recursively through ParseExprs; struct name,
// match type and count are copied over unchanged.
expr::TypedExprPtr
ProtoParser::ParseMatchExprs(const proto::plan::MatchExpr& expr_pb) {
    auto parsed_predicate = this->ParseExprs(expr_pb.predicate());
    return std::make_shared<expr::MatchExpr>(expr_pb.struct_name(),
                                             expr_pb.match_type(),
                                             expr_pb.count(),
                                             parsed_predicate);
}
expr::TypedExprPtr
ProtoParser::ParseCallExprs(const proto::plan::CallExpr& expr_pb) {
std::vector<expr::TypedExprPtr> parameters;
@ -566,7 +576,7 @@ ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
if (left_column_info.is_element_level()) {
Assert(left_data_type == DataType::ARRAY);
Assert(left_field.get_element_type() ==
static_cast<DataType>(left_column_info.data_type()));
static_cast<DataType>(left_column_info.element_type()));
} else {
Assert(left_data_type ==
static_cast<DataType>(left_column_info.data_type()));
@ -580,7 +590,7 @@ ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
if (right_column_info.is_element_level()) {
Assert(right_data_type == DataType::ARRAY);
Assert(right_field.get_element_type() ==
static_cast<DataType>(right_column_info.data_type()));
static_cast<DataType>(right_column_info.element_type()));
} else {
Assert(right_data_type ==
static_cast<DataType>(right_column_info.data_type()));
@ -602,7 +612,7 @@ ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) {
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
@ -641,7 +651,7 @@ ProtoParser::ParseBinaryArithOpEvalRangeExprs(
if (column_info.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() ==
static_cast<DataType>(column_info.data_type()));
static_cast<DataType>(column_info.element_type()));
} else {
Assert(data_type == static_cast<DataType>(column_info.data_type()));
}
@ -663,7 +673,7 @@ ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) {
if (column_info.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() ==
static_cast<DataType>(column_info.data_type()));
static_cast<DataType>(column_info.element_type()));
} else {
Assert(data_type == static_cast<DataType>(column_info.data_type()));
}
@ -680,7 +690,7 @@ ProtoParser::ParseJsonContainsExprs(
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
@ -715,7 +725,7 @@ ProtoParser::ParseGISFunctionFilterExprs(
if (columnInfo.is_element_level()) {
Assert(data_type == DataType::ARRAY);
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
} else {
Assert(data_type == (DataType)columnInfo.data_type());
}
@ -809,6 +819,10 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb,
"ElementFilterExpr should be handled at PlanNode level, "
"not in ParseExprs");
}
case ppe::kMatchExpr: {
result = ParseMatchExprs(expr_pb.match_expr());
break;
}
default: {
std::string s;
google::protobuf::TextFormat::PrintToString(expr_pb, &s);

View File

@ -109,6 +109,9 @@ class ProtoParser {
expr::TypedExprPtr
ParseElementFilterExprs(const proto::plan::ElementFilterExpr& expr_pb);
expr::TypedExprPtr
ParseMatchExprs(const proto::plan::MatchExpr& expr_pb);
expr::TypedExprPtr
ParseValueExprs(const proto::plan::ValueExpr& expr_pb);

View File

@ -166,8 +166,7 @@ gen_filter_res(milvus::plan::PlanNode* plan_node,
auto exprs_ =
std::make_unique<milvus::exec::ExprSet>(filters, exec_context.get());
std::vector<VectorPtr> results_;
milvus::exec::EvalCtx eval_ctx(exec_context.get(), exprs_.get());
eval_ctx.set_offset_input(offsets);
milvus::exec::EvalCtx eval_ctx(exec_context.get(), offsets);
exprs_->Eval(0, 1, true, eval_ctx, results_);
auto col_vec = std::dynamic_pointer_cast<milvus::ColumnVector>(results_[0]);

View File

@ -20,6 +20,11 @@ expr:
| PHRASEMATCH'('Identifier',' StringLiteral (',' expr)? ')' # PhraseMatch
| RANDOMSAMPLE'(' expr ')' # RandomSample
| ElementFilter'('Identifier',' expr')' # ElementFilter
| MATCH_ALL'(' Identifier ',' expr ')' # MatchAll
| MATCH_ANY'(' Identifier ',' expr ')' # MatchAny
| MATCH_LEAST'(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchLeast
| MATCH_MOST'(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchMost
| MATCH_EXACT'(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchExact
| expr POW expr # Power
| op = (ADD | SUB | BNOT | NOT) expr # Unary
// | '(' typeName ')' expr # Cast
@ -80,9 +85,15 @@ EXISTS: 'exists' | 'EXISTS';
TEXTMATCH: 'text_match'|'TEXT_MATCH';
PHRASEMATCH: 'phrase_match'|'PHRASE_MATCH';
RANDOMSAMPLE: 'random_sample' | 'RANDOM_SAMPLE';
MATCH_ALL: 'match_all' | 'MATCH_ALL';
MATCH_ANY: 'match_any' | 'MATCH_ANY';
MATCH_LEAST: 'match_least' | 'MATCH_LEAST';
MATCH_MOST: 'match_most' | 'MATCH_MOST';
MATCH_EXACT: 'match_exact' | 'MATCH_EXACT';
INTERVAL: 'interval' | 'INTERVAL';
ISO: 'iso' | 'ISO';
MINIMUM_SHOULD_MATCH: 'minimum_should_match' | 'MINIMUM_SHOULD_MATCH';
THRESHOLD: 'threshold' | 'THRESHOLD';
ASSIGN: '=';
ADD: '+';

View File

@ -64,7 +64,8 @@ func FillTermExpressionValue(expr *planpb.TermExpr, templateValues map[string]*p
}
dataType := expr.GetColumnInfo().GetDataType()
if typeutil.IsArrayType(dataType) {
if len(expr.GetColumnInfo().GetNestedPath()) != 0 {
// Use element type if accessing array element
if len(expr.GetColumnInfo().GetNestedPath()) != 0 || expr.GetColumnInfo().GetIsElementLevel() {
dataType = expr.GetColumnInfo().GetElementType()
}
}
@ -91,7 +92,8 @@ func FillUnaryRangeExpressionValue(expr *planpb.UnaryRangeExpr, templateValues m
dataType := expr.GetColumnInfo().GetDataType()
if typeutil.IsArrayType(dataType) {
if len(expr.GetColumnInfo().GetNestedPath()) != 0 {
// Use element type if accessing array element
if len(expr.GetColumnInfo().GetNestedPath()) != 0 || expr.GetColumnInfo().GetIsElementLevel() {
dataType = expr.GetColumnInfo().GetElementType()
}
}
@ -107,7 +109,8 @@ func FillUnaryRangeExpressionValue(expr *planpb.UnaryRangeExpr, templateValues m
func FillBinaryRangeExpressionValue(expr *planpb.BinaryRangeExpr, templateValues map[string]*planpb.GenericValue) error {
var ok bool
dataType := expr.GetColumnInfo().GetDataType()
if typeutil.IsArrayType(dataType) && len(expr.GetColumnInfo().GetNestedPath()) != 0 {
// Use element type if accessing array element
if typeutil.IsArrayType(dataType) && (len(expr.GetColumnInfo().GetNestedPath()) != 0 || expr.GetColumnInfo().GetIsElementLevel()) {
dataType = expr.GetColumnInfo().GetElementType()
}
lowerValue := expr.GetLowerValue()

File diff suppressed because one or more lines are too long

View File

@ -16,56 +16,62 @@ EXISTS=15
TEXTMATCH=16
PHRASEMATCH=17
RANDOMSAMPLE=18
INTERVAL=19
ISO=20
MINIMUM_SHOULD_MATCH=21
ASSIGN=22
ADD=23
SUB=24
MUL=25
DIV=26
MOD=27
POW=28
SHL=29
SHR=30
BAND=31
BOR=32
BXOR=33
AND=34
OR=35
ISNULL=36
ISNOTNULL=37
BNOT=38
NOT=39
IN=40
EmptyArray=41
JSONContains=42
JSONContainsAll=43
JSONContainsAny=44
ArrayContains=45
ArrayContainsAll=46
ArrayContainsAny=47
ArrayLength=48
ElementFilter=49
STEuqals=50
STTouches=51
STOverlaps=52
STCrosses=53
STContains=54
STIntersects=55
STWithin=56
STDWithin=57
STIsValid=58
BooleanConstant=59
IntegerConstant=60
FloatingConstant=61
Identifier=62
Meta=63
StringLiteral=64
JSONIdentifier=65
StructSubFieldIdentifier=66
Whitespace=67
Newline=68
MATCH_ALL=19
MATCH_ANY=20
MATCH_LEAST=21
MATCH_MOST=22
MATCH_EXACT=23
INTERVAL=24
ISO=25
MINIMUM_SHOULD_MATCH=26
THRESHOLD=27
ASSIGN=28
ADD=29
SUB=30
MUL=31
DIV=32
MOD=33
POW=34
SHL=35
SHR=36
BAND=37
BOR=38
BXOR=39
AND=40
OR=41
ISNULL=42
ISNOTNULL=43
BNOT=44
NOT=45
IN=46
EmptyArray=47
JSONContains=48
JSONContainsAll=49
JSONContainsAny=50
ArrayContains=51
ArrayContainsAll=52
ArrayContainsAny=53
ArrayLength=54
ElementFilter=55
STEuqals=56
STTouches=57
STOverlaps=58
STCrosses=59
STContains=60
STIntersects=61
STWithin=62
STDWithin=63
STIsValid=64
BooleanConstant=65
IntegerConstant=66
FloatingConstant=67
Identifier=68
Meta=69
StringLiteral=70
JSONIdentifier=71
StructSubFieldIdentifier=72
Whitespace=73
Newline=74
'('=1
')'=2
'['=3
@ -79,17 +85,17 @@ Newline=68
'>='=11
'=='=12
'!='=13
'='=22
'+'=23
'-'=24
'*'=25
'/'=26
'%'=27
'**'=28
'<<'=29
'>>'=30
'&'=31
'|'=32
'^'=33
'~'=38
'$meta'=63
'='=28
'+'=29
'-'=30
'*'=31
'/'=32
'%'=33
'**'=34
'<<'=35
'>>'=36
'&'=37
'|'=38
'^'=39
'~'=44
'$meta'=69

File diff suppressed because one or more lines are too long

View File

@ -16,56 +16,62 @@ EXISTS=15
TEXTMATCH=16
PHRASEMATCH=17
RANDOMSAMPLE=18
INTERVAL=19
ISO=20
MINIMUM_SHOULD_MATCH=21
ASSIGN=22
ADD=23
SUB=24
MUL=25
DIV=26
MOD=27
POW=28
SHL=29
SHR=30
BAND=31
BOR=32
BXOR=33
AND=34
OR=35
ISNULL=36
ISNOTNULL=37
BNOT=38
NOT=39
IN=40
EmptyArray=41
JSONContains=42
JSONContainsAll=43
JSONContainsAny=44
ArrayContains=45
ArrayContainsAll=46
ArrayContainsAny=47
ArrayLength=48
ElementFilter=49
STEuqals=50
STTouches=51
STOverlaps=52
STCrosses=53
STContains=54
STIntersects=55
STWithin=56
STDWithin=57
STIsValid=58
BooleanConstant=59
IntegerConstant=60
FloatingConstant=61
Identifier=62
Meta=63
StringLiteral=64
JSONIdentifier=65
StructSubFieldIdentifier=66
Whitespace=67
Newline=68
MATCH_ALL=19
MATCH_ANY=20
MATCH_LEAST=21
MATCH_MOST=22
MATCH_EXACT=23
INTERVAL=24
ISO=25
MINIMUM_SHOULD_MATCH=26
THRESHOLD=27
ASSIGN=28
ADD=29
SUB=30
MUL=31
DIV=32
MOD=33
POW=34
SHL=35
SHR=36
BAND=37
BOR=38
BXOR=39
AND=40
OR=41
ISNULL=42
ISNOTNULL=43
BNOT=44
NOT=45
IN=46
EmptyArray=47
JSONContains=48
JSONContainsAll=49
JSONContainsAny=50
ArrayContains=51
ArrayContainsAll=52
ArrayContainsAny=53
ArrayLength=54
ElementFilter=55
STEuqals=56
STTouches=57
STOverlaps=58
STCrosses=59
STContains=60
STIntersects=61
STWithin=62
STDWithin=63
STIsValid=64
BooleanConstant=65
IntegerConstant=66
FloatingConstant=67
Identifier=68
Meta=69
StringLiteral=70
JSONIdentifier=71
StructSubFieldIdentifier=72
Whitespace=73
Newline=74
'('=1
')'=2
'['=3
@ -79,17 +85,17 @@ Newline=68
'>='=11
'=='=12
'!='=13
'='=22
'+'=23
'-'=24
'*'=25
'/'=26
'%'=27
'**'=28
'<<'=29
'>>'=30
'&'=31
'|'=32
'^'=33
'~'=38
'$meta'=63
'='=28
'+'=29
'-'=30
'*'=31
'/'=32
'%'=33
'**'=34
'<<'=35
'>>'=36
'&'=37
'|'=38
'^'=39
'~'=44
'$meta'=69

View File

@ -11,6 +11,10 @@ func (v *BasePlanVisitor) VisitString(ctx *StringContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitMatchAny is the base handler for MatchAny parse-tree nodes; it adds no
// behavior of its own and simply delegates to VisitChildren.
func (v *BasePlanVisitor) VisitMatchAny(ctx *MatchAnyContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitFloating(ctx *FloatingContext) interface{} {
return v.VisitChildren(ctx)
}
@ -59,6 +63,10 @@ func (v *BasePlanVisitor) VisitPhraseMatch(ctx *PhraseMatchContext) interface{}
return v.VisitChildren(ctx)
}
// VisitMatchLeast is the base handler for MatchLeast parse-tree nodes; it adds
// no behavior of its own and simply delegates to VisitChildren.
func (v *BasePlanVisitor) VisitMatchLeast(ctx *MatchLeastContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitArrayLength(ctx *ArrayLengthContext) interface{} {
return v.VisitChildren(ctx)
}
@ -83,6 +91,10 @@ func (v *BasePlanVisitor) VisitRange(ctx *RangeContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitMatchAll is the base handler for MatchAll parse-tree nodes; it adds no
// behavior of its own and simply delegates to VisitChildren.
func (v *BasePlanVisitor) VisitMatchAll(ctx *MatchAllContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTIsValid(ctx *STIsValidContext) interface{} {
return v.VisitChildren(ctx)
}
@ -103,6 +115,10 @@ func (v *BasePlanVisitor) VisitSTOverlaps(ctx *STOverlapsContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitMatchMost is the base handler for MatchMost parse-tree nodes; it adds
// no behavior of its own and simply delegates to VisitChildren.
func (v *BasePlanVisitor) VisitMatchMost(ctx *MatchMostContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{} {
return v.VisitChildren(ctx)
}
@ -115,6 +131,10 @@ func (v *BasePlanVisitor) VisitParens(ctx *ParensContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitMatchExact is the base handler for MatchExact parse-tree nodes; it adds
// no behavior of its own and simply delegates to VisitChildren.
func (v *BasePlanVisitor) VisitMatchExact(ctx *MatchExactContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{} {
return v.VisitChildren(ctx)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -10,6 +10,9 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#String.
VisitString(ctx *StringContext) interface{}
// Visit a parse tree produced by PlanParser#MatchAny.
VisitMatchAny(ctx *MatchAnyContext) interface{}
// Visit a parse tree produced by PlanParser#Floating.
VisitFloating(ctx *FloatingContext) interface{}
@ -46,6 +49,9 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#PhraseMatch.
VisitPhraseMatch(ctx *PhraseMatchContext) interface{}
// Visit a parse tree produced by PlanParser#MatchLeast.
VisitMatchLeast(ctx *MatchLeastContext) interface{}
// Visit a parse tree produced by PlanParser#ArrayLength.
VisitArrayLength(ctx *ArrayLengthContext) interface{}
@ -64,6 +70,9 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#Range.
VisitRange(ctx *RangeContext) interface{}
// Visit a parse tree produced by PlanParser#MatchAll.
VisitMatchAll(ctx *MatchAllContext) interface{}
// Visit a parse tree produced by PlanParser#STIsValid.
VisitSTIsValid(ctx *STIsValidContext) interface{}
@ -79,6 +88,9 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#STOverlaps.
VisitSTOverlaps(ctx *STOverlapsContext) interface{}
// Visit a parse tree produced by PlanParser#MatchMost.
VisitMatchMost(ctx *MatchMostContext) interface{}
// Visit a parse tree produced by PlanParser#JSONIdentifier.
VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{}
@ -88,6 +100,9 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#Parens.
VisitParens(ctx *ParensContext) interface{}
// Visit a parse tree produced by PlanParser#MatchExact.
VisitMatchExact(ctx *MatchExactContext) interface{}
// Visit a parse tree produced by PlanParser#JSONContainsAll.
VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{}

View File

@ -192,7 +192,7 @@ func (v *ParserVisitor) VisitString(ctx *parser.StringContext) interface{} {
}
func checkDirectComparisonBinaryField(columnInfo *planpb.ColumnInfo) error {
if typeutil.IsArrayType(columnInfo.GetDataType()) && len(columnInfo.GetNestedPath()) == 0 {
if typeutil.IsArrayType(columnInfo.GetDataType()) && len(columnInfo.GetNestedPath()) == 0 && !columnInfo.GetIsElementLevel() {
return errors.New("can not comparisons array fields directly")
}
return nil
@ -664,6 +664,10 @@ func isElementFilterExpr(expr *ExprWithType) bool {
return expr.expr.GetElementFilterExpr() != nil
}
// isMatchExpr reports whether the wrapped plan expression is a MatchExpr.
func isMatchExpr(expr *ExprWithType) bool {
	matchExpr := expr.expr.GetMatchExpr()
	return matchExpr != nil
}
const EPSILON = 1e-10
func (v *ParserVisitor) VisitRandomSample(ctx *parser.RandomSampleContext) interface{} {
@ -715,7 +719,10 @@ func (v *ParserVisitor) VisitTerm(ctx *parser.TermContext) interface{} {
}
dataType := columnInfo.GetDataType()
if typeutil.IsArrayType(dataType) && len(columnInfo.GetNestedPath()) != 0 {
// Use element type for IN operation in two cases:
// 1. Array with nested path (e.g., arr[0] IN [1, 2, 3])
// 2. Array with element level flag (e.g., $[intField] IN [1, 2] in MATCH_ALL/ElementFilter)
if typeutil.IsArrayType(dataType) && (len(columnInfo.GetNestedPath()) != 0 || columnInfo.GetIsElementLevel()) {
dataType = columnInfo.GetElementType()
}
@ -2298,9 +2305,9 @@ func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) i
// Remove "$[" prefix and "]" suffix
fieldName := tokenText[2 : len(tokenText)-1]
// Check if we're inside an ElementFilter context
// Check if we're inside an ElementFilter or MATCH_* context
if v.currentStructArrayField == "" {
return fmt.Errorf("$[%s] syntax can only be used inside ElementFilter", fieldName)
return fmt.Errorf("$[%s] syntax can only be used inside ElementFilter or MATCH_*", fieldName)
}
// Construct full field name for struct array field
@ -2311,7 +2318,7 @@ func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) i
return fmt.Errorf("array field not found: %s, error: %s", fullFieldName, err)
}
// In element-level context, data_type should be the element type
// In element-level context, use Array as storage type, element type for operations
elementType := field.GetElementType()
return &ExprWithType{
@ -2320,12 +2327,12 @@ func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) i
ColumnExpr: &planpb.ColumnExpr{
Info: &planpb.ColumnInfo{
FieldId: field.FieldID,
DataType: elementType, // Use element type, not storage type
DataType: schemapb.DataType_Array, // Storage type is Array
IsPrimaryKey: field.IsPrimaryKey,
IsAutoID: field.AutoID,
IsPartitionKey: field.IsPartitionKey,
IsClusteringKey: field.IsClusteringKey,
ElementType: elementType,
ElementType: elementType, // Element type for operations
Nullable: field.GetNullable(),
IsElementLevel: true, // Mark as element-level access
},
@ -2336,3 +2343,113 @@ func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) i
nodeDependent: true,
}
}
// parseMatchExpr is the shared implementation behind the MATCH_* visitors.
//
//   - structArrayFieldName: field the predicate's $[...] references resolve against
//   - exprCtx: parse tree of the element-level predicate
//   - matchType: which match operation to build (MatchAll, MatchAny, MatchLeast,
//     MatchMost, MatchExact)
//   - count: threshold N for MatchLeast/MatchMost/MatchExact; callers pass 0 for
//     MatchAll/MatchAny
//   - funcName: user-facing function name, used only in error messages
//
// It returns a bool-typed *ExprWithType wrapping a planpb.MatchExpr, or an error
// value when a match context is already active or the predicate cannot be parsed.
func (v *ParserVisitor) parseMatchExpr(structArrayFieldName string, exprCtx parser.IExprContext, matchType planpb.MatchType, count int64, funcName string) interface{} {
	// A non-empty current field means we are already inside an element-level
	// context, so nesting is rejected.
	if outerField := v.currentStructArrayField; outerField != "" {
		return fmt.Errorf("nested %s is not supported, already inside match expression for field: %s", funcName, outerField)
	}

	// Expose the struct array field while the predicate is visited, and clear
	// it again on every exit path.
	v.currentStructArrayField = structArrayFieldName
	defer func() { v.currentStructArrayField = "" }()

	visited := exprCtx.Accept(v)
	if err := getError(visited); err != nil {
		return fmt.Errorf("cannot parse predicate expression: %s, error: %s", exprCtx.GetText(), err)
	}

	predicateExpr := getExpr(visited)
	if predicateExpr == nil {
		return fmt.Errorf("invalid predicate expression in %s: %s", funcName, exprCtx.GetText())
	}

	matchExpr := &planpb.MatchExpr{
		StructName: structArrayFieldName,
		Predicate:  predicateExpr.expr,
		MatchType:  matchType,
		Count:      count,
	}
	wrapped := &planpb.Expr{
		Expr: &planpb.Expr_MatchExpr{MatchExpr: matchExpr},
	}
	return &ExprWithType{
		expr:     wrapped,
		dataType: schemapb.DataType_Bool,
	}
}
// VisitMatchAll builds the plan expression for MATCH_ALL.
// Syntax: MATCH_ALL(structArrayField, $[intField] == 1 && $[strField] == "aaa")
// Semantics: every element must satisfy the predicate. No threshold applies,
// so the count is passed as 0.
func (v *ParserVisitor) VisitMatchAll(ctx *parser.MatchAllContext) interface{} {
	return v.parseMatchExpr(ctx.Identifier().GetText(), ctx.Expr(), planpb.MatchType_MatchAll, 0, "MATCH_ALL")
}
// VisitMatchAny builds the plan expression for MATCH_ANY.
// Syntax: MATCH_ANY(structArrayField, $[intField] == 1 && $[strField] == "aaa")
// Semantics: one or more elements must satisfy the predicate. No threshold
// applies, so the count is passed as 0.
func (v *ParserVisitor) VisitMatchAny(ctx *parser.MatchAnyContext) interface{} {
	return v.parseMatchExpr(ctx.Identifier().GetText(), ctx.Expr(), planpb.MatchType_MatchAny, 0, "MATCH_ANY")
}
// VisitMatchLeast handles MATCH_LEAST expressions
// Syntax: MATCH_LEAST(structArrayField, $[intField] == 1 && $[strField] == "aaa", threshold=N)
// At least N elements must match the predicate; N must be strictly positive.
//
// Returns the result of parseMatchExpr, or an error value when the threshold
// literal cannot be parsed as a 64-bit integer or is not positive.
func (v *ParserVisitor) VisitMatchLeast(ctx *parser.MatchLeastContext) interface{} {
	structArrayFieldName := ctx.Identifier().GetText()
	countStr := ctx.IntegerConstant().GetText()
	// Base 0 lets strconv honor a 0x/0o/0b/leading-0 prefix on the literal
	// instead of rejecting it or silently reading an octal literal as decimal.
	// NOTE(review): assumes IntegerConstant can lex prefixed literals — confirm
	// against the grammar's IntegerConstant rule.
	count, err := strconv.ParseInt(countStr, 0, 64)
	if err != nil {
		return fmt.Errorf("invalid count in MATCH_LEAST: %s", countStr)
	}
	if count <= 0 {
		return fmt.Errorf("count in MATCH_LEAST must be positive, got: %d", count)
	}
	return v.parseMatchExpr(structArrayFieldName, ctx.Expr(), planpb.MatchType_MatchLeast, count, "MATCH_LEAST")
}
// VisitMatchMost handles MATCH_MOST expressions
// Syntax: MATCH_MOST(structArrayField, $[intField] == 1 && $[strField] == "aaa", threshold=N)
// At most N elements may match the predicate; N may be zero but not negative.
//
// Returns the result of parseMatchExpr, or an error value when the threshold
// literal cannot be parsed as a 64-bit integer or is negative.
func (v *ParserVisitor) VisitMatchMost(ctx *parser.MatchMostContext) interface{} {
	structArrayFieldName := ctx.Identifier().GetText()
	countStr := ctx.IntegerConstant().GetText()
	// Base 0 lets strconv honor a 0x/0o/0b/leading-0 prefix on the literal
	// instead of rejecting it or silently reading an octal literal as decimal.
	// NOTE(review): assumes IntegerConstant can lex prefixed literals — confirm
	// against the grammar's IntegerConstant rule.
	count, err := strconv.ParseInt(countStr, 0, 64)
	if err != nil {
		return fmt.Errorf("invalid count in MATCH_MOST: %s", countStr)
	}
	if count < 0 {
		return fmt.Errorf("count in MATCH_MOST cannot be negative, got: %d", count)
	}
	return v.parseMatchExpr(structArrayFieldName, ctx.Expr(), planpb.MatchType_MatchMost, count, "MATCH_MOST")
}
// VisitMatchExact handles MATCH_EXACT expressions
// Syntax: MATCH_EXACT(structArrayField, $[intField] == 1 && $[strField] == "aaa", threshold=N)
// Exactly N elements must match the predicate; N may be zero but not negative.
//
// Returns the result of parseMatchExpr, or an error value when the threshold
// literal cannot be parsed as a 64-bit integer or is negative.
func (v *ParserVisitor) VisitMatchExact(ctx *parser.MatchExactContext) interface{} {
	structArrayFieldName := ctx.Identifier().GetText()
	countStr := ctx.IntegerConstant().GetText()
	// Base 0 lets strconv honor a 0x/0o/0b/leading-0 prefix on the literal
	// instead of rejecting it or silently reading an octal literal as decimal.
	// NOTE(review): assumes IntegerConstant can lex prefixed literals — confirm
	// against the grammar's IntegerConstant rule.
	count, err := strconv.ParseInt(countStr, 0, 64)
	if err != nil {
		return fmt.Errorf("invalid count in MATCH_EXACT: %s", countStr)
	}
	if count < 0 {
		return fmt.Errorf("count in MATCH_EXACT cannot be negative, got: %d", count)
	}
	return v.parseMatchExpr(structArrayFieldName, ctx.Expr(), planpb.MatchType_MatchExact, count, "MATCH_EXACT")
}

View File

@ -2557,3 +2557,150 @@ func TestExpr_ElementFilter(t *testing.T) {
assertInvalidExpr(t, helper, expr)
}
}
func TestExpr_Match(t *testing.T) {
schema := newTestSchema(true)
helper, err := typeutil.CreateSchemaHelper(schema)
assert.NoError(t, err)
// Valid MATCH_ALL expressions
validExprs := []string{
// MATCH_ALL: all elements must match
`MATCH_ALL(struct_array, $[sub_int] > 1)`,
`MATCH_ALL(struct_array, $[sub_int] == 100)`,
`MATCH_ALL(struct_array, $[sub_str] == "aaa")`,
`MATCH_ALL(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100)`,
`MATCH_ALL(struct_array, $[sub_str] != "" || $[sub_int] >= 0)`,
// MATCH_ANY: at least one element must match
`MATCH_ANY(struct_array, $[sub_int] > 1)`,
`MATCH_ANY(struct_array, $[sub_int] == 100)`,
`MATCH_ANY(struct_array, $[sub_str] == "aaa")`,
`MATCH_ANY(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100)`,
// MATCH_LEAST: at least N elements must match
`MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=3)`,
`MATCH_LEAST(struct_array, $[sub_str] == "aaa", threshold=1)`,
`MATCH_LEAST(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100, threshold=2)`,
// MATCH_MOST: at most N elements must match
`MATCH_MOST(struct_array, $[sub_int] > 1, threshold=3)`,
`MATCH_MOST(struct_array, $[sub_str] == "aaa", threshold=0)`,
`MATCH_MOST(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100, threshold=5)`,
// MATCH_EXACT: exactly N elements must match
`MATCH_EXACT(struct_array, $[sub_int] > 1, threshold=2)`,
`MATCH_EXACT(struct_array, $[sub_str] == "aaa", threshold=0)`,
`MATCH_EXACT(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100, threshold=3)`,
// Combined with other expressions (match must be last)
`Int64Field > 0 && MATCH_ALL(struct_array, $[sub_int] > 1)`,
`Int64Field > 0 && MATCH_ANY(struct_array, $[sub_str] == "test")`,
`Int64Field > 0 && MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=2)`,
// Complex predicates
`MATCH_ALL(struct_array, ($[sub_int] > 0 && $[sub_int] < 100) || $[sub_str] == "default")`,
`MATCH_ANY(struct_array, !($[sub_int] < 0))`,
// Case insensitivity
`match_all(struct_array, $[sub_int] > 1)`,
`match_any(struct_array, $[sub_int] > 1)`,
`match_least(struct_array, $[sub_int] > 1, threshold=2)`,
`match_most(struct_array, $[sub_int] > 1, threshold=2)`,
`match_exact(struct_array, $[sub_int] > 1, threshold=2)`,
// Multiple match expressions with logical operators
`MATCH_ALL(struct_array, $[sub_int] > 1) || MATCH_ANY(struct_array, $[sub_str] == "test")`,
`MATCH_ALL(struct_array, $[sub_int] > 1) && MATCH_ANY(struct_array, $[sub_str] == "test")`,
`MATCH_ANY(struct_array, $[sub_int] > 1) || Int64Field > 0`,
`MATCH_ALL(struct_array, $[sub_int] > 1) && Int64Field > 0`,
}
for _, expr := range validExprs {
assertValidExpr(t, helper, expr)
}
// Test proto structure assertions
t.Run("MatchAll_Proto", func(t *testing.T) {
expr, err := ParseExpr(helper, `MATCH_ALL(struct_array, $[sub_int] > 1)`, nil)
assert.NoError(t, err)
assert.NotNil(t, expr.GetMatchExpr())
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
assert.Equal(t, planpb.MatchType_MatchAll, expr.GetMatchExpr().GetMatchType())
assert.Equal(t, int64(0), expr.GetMatchExpr().GetCount())
})
t.Run("MatchAny_Proto", func(t *testing.T) {
expr, err := ParseExpr(helper, `MATCH_ANY(struct_array, $[sub_str] == "aaa")`, nil)
assert.NoError(t, err)
assert.NotNil(t, expr.GetMatchExpr())
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
assert.Equal(t, planpb.MatchType_MatchAny, expr.GetMatchExpr().GetMatchType())
assert.Equal(t, int64(0), expr.GetMatchExpr().GetCount())
})
t.Run("MatchLeast_Proto", func(t *testing.T) {
expr, err := ParseExpr(helper, `MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=3)`, nil)
assert.NoError(t, err)
assert.NotNil(t, expr.GetMatchExpr())
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
assert.Equal(t, planpb.MatchType_MatchLeast, expr.GetMatchExpr().GetMatchType())
assert.Equal(t, int64(3), expr.GetMatchExpr().GetCount())
})
t.Run("MatchMost_Proto", func(t *testing.T) {
expr, err := ParseExpr(helper, `MATCH_MOST(struct_array, $[sub_str] == "aaa", threshold=5)`, nil)
assert.NoError(t, err)
assert.NotNil(t, expr.GetMatchExpr())
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
assert.Equal(t, planpb.MatchType_MatchMost, expr.GetMatchExpr().GetMatchType())
assert.Equal(t, int64(5), expr.GetMatchExpr().GetCount())
})
t.Run("MatchExact_Proto", func(t *testing.T) {
expr, err := ParseExpr(helper, `MATCH_EXACT(struct_array, $[sub_int] == 100, threshold=2)`, nil)
assert.NoError(t, err)
assert.NotNil(t, expr.GetMatchExpr())
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
assert.Equal(t, planpb.MatchType_MatchExact, expr.GetMatchExpr().GetMatchType())
assert.Equal(t, int64(2), expr.GetMatchExpr().GetCount())
})
// Invalid expressions
invalidExprs := []string{
// Nested match expressions not allowed
`MATCH_ALL(struct_array, MATCH_ANY(struct_array, $[sub_int] > 1))`,
`MATCH_ANY(struct_array, $[sub_int] > 1 && MATCH_ALL(struct_array, $[sub_str] == "1"))`,
// $[field] syntax outside match context
`$[sub_int] > 1`,
`Int64Field > 0 && $[sub_int] > 1`,
// Non-existent fields
`MATCH_ALL(struct_array, $[non_existent_field] > 1)`,
`MATCH_ALL(non_existent_array, $[sub_int] > 1)`,
// Missing parameters
`MATCH_ALL(struct_array)`,
`MATCH_ALL()`,
`MATCH_ANY(struct_array)`,
`MATCH_ANY()`,
`MATCH_LEAST(struct_array, $[sub_int] > 1)`, // missing count
`MATCH_MOST(struct_array, $[sub_int] > 1)`, // missing count
`MATCH_EXACT(struct_array, $[sub_int] > 1)`, // missing count
// MATCH_ALL/MATCH_ANY should not have count parameter
`MATCH_ALL(struct_array, $[sub_int] > 1, 3)`,
`MATCH_ANY(struct_array, $[sub_int] > 1, 2)`,
// Invalid count values
`MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=0)`, // count must be positive for MATCH_LEAST
`MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=-1)`, // negative count
`MATCH_MOST(struct_array, $[sub_int] > 1, threshold=-1)`, // negative count
`MATCH_EXACT(struct_array, $[sub_int] > 1, threshold=-1)`, // negative count
}
for _, expr := range invalidExprs {
assertInvalidExpr(t, helper, expr)
}
}

View File

@ -362,8 +362,14 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr,
func handleCompareRightValue(op planpb.OpType, left *ExprWithType, right *planpb.ValueExpr) (*planpb.Expr, error) {
dataType := left.dataType
if typeutil.IsArrayType(dataType) && len(toColumnInfo(left).GetNestedPath()) != 0 {
dataType = toColumnInfo(left).GetElementType()
columnInfo := toColumnInfo(left)
// Use element type for casting in two cases:
// 1. Array with nested path (e.g., arr[0])
// 2. Array with element level flag (e.g., $[intField] in MATCH_ALL/ElementFilter)
if typeutil.IsArrayType(dataType) && columnInfo != nil &&
(len(columnInfo.GetNestedPath()) != 0 || columnInfo.GetIsElementLevel()) {
dataType = columnInfo.GetElementType()
}
if !left.expr.GetIsTemplate() && !isTemplateExpr(right) {
@ -378,7 +384,6 @@ func handleCompareRightValue(op planpb.OpType, left *ExprWithType, right *planpb
return handleBinaryArithExpr(op, leftArithExpr, left.dataType, right)
}
columnInfo := toColumnInfo(left)
if columnInfo == nil {
return nil, errors.New("not supported to combine multiple fields")
}

View File

@ -252,6 +252,22 @@ message ElementFilterExpr {
Expr predicate = 3;
}
// MatchType selects how the number of struct-array elements that satisfy a
// MatchExpr predicate is turned into a per-row result. "N" below refers to
// MatchExpr.count.
enum MatchType {
MatchAll = 0; // Every element must match the predicate (zero value of this enum)
MatchAny = 1; // At least one element matches the predicate
MatchLeast = 2; // At least N elements match the predicate
MatchMost = 3; // At most N elements match the predicate
MatchExact = 4; // Exactly N elements match the predicate
}
// MatchExpr evaluates an element-level predicate over a struct array field and
// combines the per-element outcomes according to match_type.
message MatchExpr {
string struct_name = 1; // The struct array field name (e.g., struct_array)
Expr predicate = 2; // The condition expression using $[field] syntax (e.g., $[intField] == 1 && $[strField] == "aaa")
MatchType match_type = 3; // Which MatchType rule to apply to the matching-element count
int64 count = 4; // N for MatchLeast/MatchMost/MatchExact; the parser tests expect 0 for MatchAll/MatchAny
}
// AlwaysTrueExpr carries no fields. NOTE(review): presumably a predicate that accepts every row — inferred from the name, confirm against the evaluator.
message AlwaysTrueExpr {}
message Interval {
@ -293,6 +309,7 @@ message Expr {
GISFunctionFilterExpr gisfunction_filter_expr = 17;
TimestamptzArithCompareExpr timestamptz_arith_compare_expr = 18;
ElementFilterExpr element_filter_expr = 19;
MatchExpr match_expr = 21;
};
bool is_template = 20;
}

File diff suppressed because it is too large Load Diff