mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
feat: support match operator family (#46518)
issue: https://github.com/milvus-io/milvus/issues/46517 ref: https://github.com/milvus-io/milvus/issues/42148 This PR supports match operator family with struct array and brute force search only. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> - Core invariant: match operators only target struct-array element-level predicates and assume callers provide a correct row_start so element indices form a contiguous range; IArrayOffsets implementations convert row-level bitmaps/rows (starting at row_start) into element-level bitmaps or a contiguous element-offset vector used by brute-force evaluation. - New capability added: end-to-end support for MATCH_* semantics (match_any, match_all, match_least, match_most, match_exact) — parser (grammar + proto), planner (ParseMatchExprs), expr model (expr::MatchExpr), compilation (Expr→PhyMatchFilterExpr), execution (PhyMatchFilterExpr::Eval uses element offsets/bitmaps), and unit tests (MatchExprTest + parser tests). Implementation currently works for struct-array inputs and uses brute-force element counting via RowBitsetToElementOffsets/RowBitsetToElementBitset. - Logic removed or simplified and why: removed the ad-hoc DocBitsetToElementOffsets helper and consolidated offset/bitset derivation into IArrayOffsets::RowBitsetToElementOffsets and a row_start-aware RowBitsetToElementBitset, and removed EvalCtx overloads that embedded ExprSet (now EvalCtx(exec_ctx, offset_input)). This centralizes array-layout logic in ArrayOffsets and removes duplicated offset conversion and EvalCtx variants that were redundant for element-level evaluation. - No data loss / no behavior regression: persistent formats are unchanged (no proto storage or on-disk layout changed); callers were updated to supply row_start and now route through the centralized ArrayOffsets APIs which still use the authoritative row_to_element_start_ mapping, preserving exact element index mappings. Eval logic changes are limited to in-memory plumbing (how offsets/bitmaps are produced and how EvalCtx is constructed); expression evaluation still invokes exprs_->Eval where needed, so existing behavior and stored data remain intact. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: SpadeA <tangchenjie1210@gmail.com> Signed-off-by: SpadeA-Tang <tangchenjie1210@gmail.com>
This commit is contained in:
parent
0d70d2b98c
commit
0114bd1dc9
@ -46,25 +46,73 @@ ArrayOffsetsSealed::ElementIDRangeOfRow(int32_t row_id) const {
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
ArrayOffsetsSealed::RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const {
|
||||
int64_t row_count = GetRowCount();
|
||||
int64_t element_count = GetTotalElementCount();
|
||||
const TargetBitmapView& valid_row_bitset,
|
||||
int64_t row_start) const {
|
||||
int64_t row_count = row_bitset.size();
|
||||
AssertInfo(row_start >= 0 && row_start + row_count <= GetRowCount(),
|
||||
"row range out of bounds: row_start={}, row_count={}, "
|
||||
"total_rows={}",
|
||||
row_start,
|
||||
row_count,
|
||||
GetRowCount());
|
||||
|
||||
int64_t element_start = row_to_element_start_[row_start];
|
||||
int64_t element_end = row_to_element_start_[row_start + row_count];
|
||||
int64_t element_count = element_end - element_start;
|
||||
|
||||
TargetBitmap element_bitset(element_count);
|
||||
TargetBitmap valid_element_bitset(element_count);
|
||||
|
||||
for (int64_t row_id = 0; row_id < row_count; ++row_id) {
|
||||
int64_t start = row_to_element_start_[row_id];
|
||||
int64_t end = row_to_element_start_[row_id + 1];
|
||||
for (int64_t i = 0; i < row_count; ++i) {
|
||||
int64_t row_id = row_start + i;
|
||||
int64_t start = row_to_element_start_[row_id] - element_start;
|
||||
int64_t end = row_to_element_start_[row_id + 1] - element_start;
|
||||
if (start < end) {
|
||||
element_bitset.set(start, end - start, row_bitset[row_id]);
|
||||
valid_element_bitset.set(
|
||||
start, end - start, valid_row_bitset[row_id]);
|
||||
element_bitset.set(start, end - start, row_bitset[i]);
|
||||
valid_element_bitset.set(start, end - start, valid_row_bitset[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return {std::move(element_bitset), std::move(valid_element_bitset)};
|
||||
}
|
||||
|
||||
FixedVector<int32_t>
|
||||
ArrayOffsetsSealed::RowBitsetToElementOffsets(
|
||||
const TargetBitmapView& row_bitset, int64_t row_start) const {
|
||||
int64_t row_count = row_bitset.size();
|
||||
AssertInfo(row_start >= 0 && row_start + row_count <= GetRowCount(),
|
||||
"row range out of bounds: row_start={}, row_count={}, "
|
||||
"total_rows={}",
|
||||
row_start,
|
||||
row_count,
|
||||
GetRowCount());
|
||||
|
||||
int64_t selected_rows = row_bitset.count();
|
||||
FixedVector<int32_t> element_offsets;
|
||||
if (selected_rows == 0) {
|
||||
return element_offsets;
|
||||
}
|
||||
|
||||
int64_t avg_elem_per_row =
|
||||
static_cast<int64_t>(element_row_ids_.size()) /
|
||||
(static_cast<int64_t>(row_to_element_start_.size()) - 1);
|
||||
|
||||
element_offsets.reserve(selected_rows * avg_elem_per_row);
|
||||
|
||||
for (int64_t i = 0; i < row_count; ++i) {
|
||||
if (row_bitset[i]) {
|
||||
int64_t row_id = row_start + i;
|
||||
int32_t first_elem = row_to_element_start_[row_id];
|
||||
int32_t last_elem = row_to_element_start_[row_id + 1];
|
||||
for (int32_t elem_id = first_elem; elem_id < last_elem; ++elem_id) {
|
||||
element_offsets.push_back(elem_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return element_offsets;
|
||||
}
|
||||
|
||||
std::shared_ptr<ArrayOffsetsSealed>
|
||||
ArrayOffsetsSealed::BuildFromSegment(const void* segment,
|
||||
const FieldMeta& field_meta) {
|
||||
@ -193,23 +241,73 @@ ArrayOffsetsGrowing::ElementIDRangeOfRow(int32_t row_id) const {
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
ArrayOffsetsGrowing::RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const {
|
||||
const TargetBitmapView& valid_row_bitset,
|
||||
int64_t row_start) const {
|
||||
std::shared_lock lock(mutex_);
|
||||
|
||||
int64_t element_count = element_row_ids_.size();
|
||||
int64_t row_count = row_bitset.size();
|
||||
AssertInfo(row_start >= 0 && row_start + row_count <= committed_row_count_,
|
||||
"row range out of bounds: row_start={}, row_count={}, "
|
||||
"committed_rows={}",
|
||||
row_start,
|
||||
row_count,
|
||||
committed_row_count_);
|
||||
|
||||
int64_t element_start = row_to_element_start_[row_start];
|
||||
int64_t element_end = row_to_element_start_[row_start + row_count];
|
||||
int64_t element_count = element_end - element_start;
|
||||
|
||||
TargetBitmap element_bitset(element_count);
|
||||
TargetBitmap valid_element_bitset(element_count);
|
||||
|
||||
// Direct access to element_row_ids_, no virtual function calls
|
||||
for (size_t elem_id = 0; elem_id < element_row_ids_.size(); ++elem_id) {
|
||||
for (int64_t elem_id = element_start; elem_id < element_end; ++elem_id) {
|
||||
auto row_id = element_row_ids_[elem_id];
|
||||
element_bitset[elem_id] = row_bitset[row_id];
|
||||
valid_element_bitset[elem_id] = valid_row_bitset[row_id];
|
||||
int64_t bitset_idx = row_id - row_start;
|
||||
element_bitset[elem_id - element_start] = row_bitset[bitset_idx];
|
||||
valid_element_bitset[elem_id - element_start] =
|
||||
valid_row_bitset[bitset_idx];
|
||||
}
|
||||
|
||||
return {std::move(element_bitset), std::move(valid_element_bitset)};
|
||||
}
|
||||
|
||||
FixedVector<int32_t>
|
||||
ArrayOffsetsGrowing::RowBitsetToElementOffsets(
|
||||
const TargetBitmapView& row_bitset, int64_t row_start) const {
|
||||
std::shared_lock lock(mutex_);
|
||||
|
||||
int64_t row_count = row_bitset.size();
|
||||
AssertInfo(row_start >= 0 && row_start + row_count <= committed_row_count_,
|
||||
"row range out of bounds: row_start={}, row_count={}, "
|
||||
"committed_rows={}",
|
||||
row_start,
|
||||
row_count,
|
||||
committed_row_count_);
|
||||
|
||||
int64_t selected_rows = row_bitset.count();
|
||||
FixedVector<int32_t> element_offsets;
|
||||
if (selected_rows == 0) {
|
||||
return element_offsets;
|
||||
}
|
||||
int64_t avg_elem_per_row =
|
||||
static_cast<int64_t>(element_row_ids_.size()) /
|
||||
(static_cast<int64_t>(row_to_element_start_.size()) - 1);
|
||||
element_offsets.reserve(selected_rows * avg_elem_per_row);
|
||||
|
||||
for (int64_t i = 0; i < row_count; ++i) {
|
||||
if (row_bitset[i]) {
|
||||
int64_t row_id = row_start + i;
|
||||
int32_t first_elem = row_to_element_start_[row_id];
|
||||
int32_t last_elem = row_to_element_start_[row_id + 1];
|
||||
for (int32_t elem_id = first_elem; elem_id < last_elem; ++elem_id) {
|
||||
element_offsets.push_back(elem_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return element_offsets;
|
||||
}
|
||||
|
||||
void
|
||||
ArrayOffsetsGrowing::Insert(int64_t row_id_start,
|
||||
const int32_t* array_lengths,
|
||||
|
||||
@ -50,10 +50,18 @@ class IArrayOffsets {
|
||||
ElementIDRangeOfRow(int32_t row_id) const = 0;
|
||||
|
||||
// Convert row-level bitsets to element-level bitsets
|
||||
// row_start: starting row index (0-based)
|
||||
// row_bitset.size(): number of rows to process
|
||||
virtual std::pair<TargetBitmap, TargetBitmap>
|
||||
RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const = 0;
|
||||
RowBitsetToElementBitset(const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset,
|
||||
int64_t row_start) const = 0;
|
||||
|
||||
// Convert row-level bitset to element offsets
|
||||
// Returns element IDs for all rows where row_bitset[row_id] is true
|
||||
virtual FixedVector<int32_t>
|
||||
RowBitsetToElementOffsets(const TargetBitmapView& row_bitset,
|
||||
int64_t row_start) const = 0;
|
||||
};
|
||||
|
||||
class ArrayOffsetsSealed : public IArrayOffsets {
|
||||
@ -93,9 +101,13 @@ class ArrayOffsetsSealed : public IArrayOffsets {
|
||||
ElementIDRangeOfRow(int32_t row_id) const override;
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const override;
|
||||
RowBitsetToElementBitset(const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset,
|
||||
int64_t row_start) const override;
|
||||
|
||||
FixedVector<int32_t>
|
||||
RowBitsetToElementOffsets(const TargetBitmapView& row_bitset,
|
||||
int64_t row_start) const override;
|
||||
|
||||
static std::shared_ptr<ArrayOffsetsSealed>
|
||||
BuildFromSegment(const void* segment, const FieldMeta& field_meta);
|
||||
@ -132,9 +144,13 @@ class ArrayOffsetsGrowing : public IArrayOffsets {
|
||||
ElementIDRangeOfRow(int32_t row_id) const override;
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
RowBitsetToElementBitset(
|
||||
const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset) const override;
|
||||
RowBitsetToElementBitset(const TargetBitmapView& row_bitset,
|
||||
const TargetBitmapView& valid_row_bitset,
|
||||
int64_t row_start) const override;
|
||||
|
||||
FixedVector<int32_t>
|
||||
RowBitsetToElementOffsets(const TargetBitmapView& row_bitset,
|
||||
int64_t row_start) const override;
|
||||
|
||||
private:
|
||||
struct PendingRow {
|
||||
|
||||
@ -114,7 +114,7 @@ TEST_F(ArrayOffsetsTest, SealedRowBitsetToElementBitset) {
|
||||
valid_row_bitset.size());
|
||||
|
||||
auto [elem_bitset, valid_elem_bitset] =
|
||||
offsets.RowBitsetToElementBitset(row_view, valid_view);
|
||||
offsets.RowBitsetToElementBitset(row_view, valid_view, 0);
|
||||
|
||||
EXPECT_EQ(elem_bitset.size(), 6);
|
||||
// Elements of row 0 (elem 0, 1) should be true
|
||||
@ -310,7 +310,7 @@ TEST_F(ArrayOffsetsTest, GrowingRowBitsetToElementBitset) {
|
||||
valid_row_bitset.size());
|
||||
|
||||
auto [elem_bitset, valid_elem_bitset] =
|
||||
offsets.RowBitsetToElementBitset(row_view, valid_view);
|
||||
offsets.RowBitsetToElementBitset(row_view, valid_view, 0);
|
||||
|
||||
EXPECT_EQ(elem_bitset.size(), 6);
|
||||
EXPECT_TRUE(elem_bitset[0]);
|
||||
|
||||
@ -79,7 +79,7 @@ ElementFilterIterator::FetchAndFilterBatch() {
|
||||
}
|
||||
|
||||
// Step 2: Batch evaluate element-level expression
|
||||
exec::EvalCtx eval_ctx(exec_context_, expr_set_, &element_ids_buffer_);
|
||||
exec::EvalCtx eval_ctx(exec_context_, &element_ids_buffer_);
|
||||
std::vector<VectorPtr> results;
|
||||
|
||||
// Evaluate the expression set (should contain only element_expr)
|
||||
|
||||
@ -27,23 +27,12 @@
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
class ExprSet;
|
||||
|
||||
using OffsetVector = FixedVector<int32_t>;
|
||||
class EvalCtx {
|
||||
public:
|
||||
EvalCtx(ExecContext* exec_ctx,
|
||||
ExprSet* expr_set,
|
||||
OffsetVector* offset_input)
|
||||
: exec_ctx_(exec_ctx),
|
||||
expr_set_(expr_set),
|
||||
offset_input_(offset_input) {
|
||||
EvalCtx(ExecContext* exec_ctx, OffsetVector* offset_input)
|
||||
: exec_ctx_(exec_ctx), offset_input_(offset_input) {
|
||||
assert(exec_ctx_ != nullptr);
|
||||
assert(expr_set_ != nullptr);
|
||||
}
|
||||
|
||||
explicit EvalCtx(ExecContext* exec_ctx, ExprSet* expr_set)
|
||||
: exec_ctx_(exec_ctx), expr_set_(expr_set) {
|
||||
}
|
||||
|
||||
explicit EvalCtx(ExecContext* exec_ctx) : exec_ctx_(exec_ctx) {
|
||||
@ -96,7 +85,6 @@ class EvalCtx {
|
||||
|
||||
private:
|
||||
ExecContext* exec_ctx_ = nullptr;
|
||||
ExprSet* expr_set_ = nullptr;
|
||||
// we may accept offsets array as input and do expr filtering on these data
|
||||
OffsetVector* offset_input_ = nullptr;
|
||||
bool input_no_nulls_ = false;
|
||||
|
||||
@ -34,7 +34,11 @@ PhyExistsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
context.set_apply_valid_data_after_flip(false);
|
||||
auto input = context.get_offset_input();
|
||||
SetHasOffsetInput((input != nullptr));
|
||||
switch (expr_->column_.data_type_) {
|
||||
auto data_type = expr_->column_.data_type_;
|
||||
if (expr_->column_.element_level_) {
|
||||
data_type = expr_->column_.element_type_;
|
||||
}
|
||||
switch (data_type) {
|
||||
case DataType::JSON: {
|
||||
if (SegmentExpr::CanUseIndex() && !has_offset_input_) {
|
||||
result = EvalJsonExistsForIndex();
|
||||
@ -44,9 +48,7 @@ PhyExistsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ThrowInfo(DataTypeInvalid,
|
||||
"unsupported data type: {}",
|
||||
expr_->column_.data_type_);
|
||||
ThrowInfo(DataTypeInvalid, "unsupported data type: {}", data_type);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@
|
||||
#include "exec/expression/JsonContainsExpr.h"
|
||||
#include "exec/expression/LogicalBinaryExpr.h"
|
||||
#include "exec/expression/LogicalUnaryExpr.h"
|
||||
#include "exec/expression/MatchExpr.h"
|
||||
#include "exec/expression/NullExpr.h"
|
||||
#include "exec/expression/TermExpr.h"
|
||||
#include "exec/expression/UnaryExpr.h"
|
||||
@ -294,6 +295,17 @@ CompileExpression(const expr::TypedExprPtr& expr,
|
||||
context->get_active_count(),
|
||||
context->query_config()->get_expr_batch_size(),
|
||||
context->get_consistency_level());
|
||||
} else if (auto match_expr =
|
||||
std::dynamic_pointer_cast<const milvus::expr::MatchExpr>(
|
||||
expr)) {
|
||||
result = std::make_shared<PhyMatchFilterExpr>(
|
||||
compiled_inputs,
|
||||
match_expr,
|
||||
"PhyMatchFilterExpr",
|
||||
op_ctx,
|
||||
context->get_segment(),
|
||||
context->get_active_count(),
|
||||
context->query_config()->get_expr_batch_size());
|
||||
} else {
|
||||
ThrowInfo(ExprInvalid, "unsupport expr: ", expr->ToString());
|
||||
}
|
||||
|
||||
258
internal/core/src/exec/expression/MatchExpr.cpp
Normal file
258
internal/core/src/exec/expression/MatchExpr.cpp
Normal file
@ -0,0 +1,258 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "MatchExpr.h"
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include "common/Tracer.h"
|
||||
#include "common/Types.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
using MatchType = milvus::expr::MatchType;
|
||||
|
||||
template <MatchType match_type, bool all_valid>
|
||||
void
|
||||
ProcessMatchRows(int64_t row_count,
|
||||
const IArrayOffsets* array_offsets,
|
||||
const TargetBitmapView& match_bitset,
|
||||
const TargetBitmapView& valid_bitset,
|
||||
TargetBitmapView& result_bitset,
|
||||
int64_t threshold) {
|
||||
for (int64_t i = 0; i < row_count; ++i) {
|
||||
auto [first_elem, last_elem] = array_offsets->ElementIDRangeOfRow(i);
|
||||
int64_t hit_count = 0;
|
||||
int64_t element_count = last_elem - first_elem;
|
||||
bool early_fail = false;
|
||||
|
||||
if constexpr (all_valid) {
|
||||
for (auto j = first_elem; j < last_elem; ++j) {
|
||||
bool matched = match_bitset[j];
|
||||
if (matched) {
|
||||
++hit_count;
|
||||
}
|
||||
|
||||
if constexpr (match_type == MatchType::MatchAny) {
|
||||
if (hit_count > 0) {
|
||||
break;
|
||||
}
|
||||
} else if constexpr (match_type == MatchType::MatchAll) {
|
||||
if (!matched) {
|
||||
early_fail = true;
|
||||
break;
|
||||
}
|
||||
} else if constexpr (match_type == MatchType::MatchLeast) {
|
||||
if (hit_count >= threshold) {
|
||||
break;
|
||||
}
|
||||
} else if constexpr (match_type == MatchType::MatchMost ||
|
||||
match_type == MatchType::MatchExact) {
|
||||
if (hit_count > threshold) {
|
||||
early_fail = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
element_count = 0;
|
||||
for (auto j = first_elem; j < last_elem; ++j) {
|
||||
if (!valid_bitset[j]) {
|
||||
continue;
|
||||
}
|
||||
++element_count;
|
||||
bool matched = match_bitset[j];
|
||||
if (matched) {
|
||||
++hit_count;
|
||||
}
|
||||
|
||||
if constexpr (match_type == MatchType::MatchAny) {
|
||||
if (hit_count > 0) {
|
||||
break;
|
||||
}
|
||||
} else if constexpr (match_type == MatchType::MatchAll) {
|
||||
if (!matched) {
|
||||
early_fail = true;
|
||||
break;
|
||||
}
|
||||
} else if constexpr (match_type == MatchType::MatchLeast) {
|
||||
if (hit_count >= threshold) {
|
||||
break;
|
||||
}
|
||||
} else if constexpr (match_type == MatchType::MatchMost ||
|
||||
match_type == MatchType::MatchExact) {
|
||||
if (hit_count > threshold) {
|
||||
early_fail = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (early_fail) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_match = false;
|
||||
if constexpr (match_type == MatchType::MatchAny) {
|
||||
is_match = hit_count > 0;
|
||||
} else if constexpr (match_type == MatchType::MatchAll) {
|
||||
is_match = hit_count == element_count;
|
||||
} else if constexpr (match_type == MatchType::MatchLeast) {
|
||||
is_match = hit_count >= threshold;
|
||||
} else if constexpr (match_type == MatchType::MatchMost) {
|
||||
is_match = hit_count <= threshold;
|
||||
} else if constexpr (match_type == MatchType::MatchExact) {
|
||||
is_match = hit_count == threshold;
|
||||
}
|
||||
|
||||
if (is_match) {
|
||||
result_bitset[i] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <MatchType match_type>
|
||||
void
|
||||
DispatchByValidity(bool all_valid,
|
||||
int64_t row_count,
|
||||
const IArrayOffsets* array_offsets,
|
||||
const TargetBitmapView& match_bitset,
|
||||
const TargetBitmapView& valid_bitset,
|
||||
TargetBitmapView& result_bitset,
|
||||
int64_t threshold) {
|
||||
if (all_valid) {
|
||||
ProcessMatchRows<match_type, true>(row_count,
|
||||
array_offsets,
|
||||
match_bitset,
|
||||
valid_bitset,
|
||||
result_bitset,
|
||||
threshold);
|
||||
} else {
|
||||
ProcessMatchRows<match_type, false>(row_count,
|
||||
array_offsets,
|
||||
match_bitset,
|
||||
valid_bitset,
|
||||
result_bitset,
|
||||
threshold);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
PhyMatchFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
tracer::AutoSpan span("PhyMatchFilterExpr::Eval", tracer::GetRootSpan());
|
||||
|
||||
auto input = context.get_offset_input();
|
||||
AssertInfo(input == nullptr,
|
||||
"Offset input in match filter expr is not implemented now");
|
||||
|
||||
auto schema = segment_->get_schema();
|
||||
auto field_meta =
|
||||
schema.GetFirstArrayFieldInStruct(expr_->get_struct_name());
|
||||
|
||||
auto array_offsets = segment_->GetArrayOffsets(field_meta.get_id());
|
||||
AssertInfo(array_offsets != nullptr, "Array offsets not available");
|
||||
|
||||
int64_t row_count =
|
||||
context.get_exec_context()->get_query_context()->get_active_count();
|
||||
result = std::make_shared<ColumnVector>(TargetBitmap(row_count, false),
|
||||
TargetBitmap(row_count, true));
|
||||
|
||||
auto col_vec = std::dynamic_pointer_cast<ColumnVector>(result);
|
||||
AssertInfo(col_vec != nullptr, "Result should be ColumnVector");
|
||||
AssertInfo(col_vec->IsBitmap(), "Result should be bitmap");
|
||||
auto col_vec_size = col_vec->size();
|
||||
TargetBitmapView bitset_view(col_vec->GetRawData(), col_vec_size);
|
||||
|
||||
auto [total_elements, _] = array_offsets->ElementIDRangeOfRow(row_count);
|
||||
FixedVector<int32_t> element_offsets(total_elements);
|
||||
std::iota(element_offsets.begin(), element_offsets.end(), 0);
|
||||
|
||||
EvalCtx eval_ctx(context.get_exec_context(), &element_offsets);
|
||||
|
||||
VectorPtr match_result;
|
||||
// TODO(SpadeA): can be executed in batch
|
||||
inputs_[0]->Eval(eval_ctx, match_result);
|
||||
auto match_result_col_vec =
|
||||
std::dynamic_pointer_cast<ColumnVector>(match_result);
|
||||
AssertInfo(match_result_col_vec != nullptr,
|
||||
"Match result should be ColumnVector");
|
||||
AssertInfo(match_result_col_vec->IsBitmap(),
|
||||
"Match result should be bitmap");
|
||||
TargetBitmapView match_result_bitset_view(
|
||||
match_result_col_vec->GetRawData(), match_result_col_vec->size());
|
||||
TargetBitmapView match_result_valid_view(
|
||||
match_result_col_vec->GetValidRawData(), match_result_col_vec->size());
|
||||
|
||||
bool all_valid = match_result_valid_view.all();
|
||||
auto match_type = expr_->get_match_type();
|
||||
int64_t threshold = expr_->get_count();
|
||||
|
||||
switch (match_type) {
|
||||
case MatchType::MatchAny:
|
||||
DispatchByValidity<MatchType::MatchAny>(all_valid,
|
||||
row_count,
|
||||
array_offsets.get(),
|
||||
match_result_bitset_view,
|
||||
match_result_valid_view,
|
||||
bitset_view,
|
||||
threshold);
|
||||
break;
|
||||
case MatchType::MatchAll:
|
||||
DispatchByValidity<MatchType::MatchAll>(all_valid,
|
||||
row_count,
|
||||
array_offsets.get(),
|
||||
match_result_bitset_view,
|
||||
match_result_valid_view,
|
||||
bitset_view,
|
||||
threshold);
|
||||
break;
|
||||
case MatchType::MatchLeast:
|
||||
DispatchByValidity<MatchType::MatchLeast>(all_valid,
|
||||
row_count,
|
||||
array_offsets.get(),
|
||||
match_result_bitset_view,
|
||||
match_result_valid_view,
|
||||
bitset_view,
|
||||
threshold);
|
||||
break;
|
||||
case MatchType::MatchMost:
|
||||
DispatchByValidity<MatchType::MatchMost>(all_valid,
|
||||
row_count,
|
||||
array_offsets.get(),
|
||||
match_result_bitset_view,
|
||||
match_result_valid_view,
|
||||
bitset_view,
|
||||
threshold);
|
||||
break;
|
||||
case MatchType::MatchExact:
|
||||
DispatchByValidity<MatchType::MatchExact>(all_valid,
|
||||
row_count,
|
||||
array_offsets.get(),
|
||||
match_result_bitset_view,
|
||||
match_result_valid_view,
|
||||
bitset_view,
|
||||
threshold);
|
||||
break;
|
||||
default:
|
||||
ThrowInfo(OpTypeInvalid,
|
||||
"Unsupported match type: {}",
|
||||
static_cast<int>(match_type));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
||||
86
internal/core/src/exec/expression/MatchExpr.h
Normal file
86
internal/core/src/exec/expression/MatchExpr.h
Normal file
@ -0,0 +1,86 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/OpContext.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
|
||||
class PhyMatchFilterExpr : public Expr {
|
||||
public:
|
||||
PhyMatchFilterExpr(
|
||||
const std::vector<std::shared_ptr<Expr>>& input,
|
||||
const std::shared_ptr<const milvus::expr::MatchExpr>& expr,
|
||||
const std::string& name,
|
||||
milvus::OpContext* op_ctx,
|
||||
const segcore::SegmentInternalInterface* segment,
|
||||
int64_t active_count,
|
||||
int64_t batch_size)
|
||||
: Expr(DataType::BOOL, std::move(input), name, op_ctx),
|
||||
expr_(expr),
|
||||
segment_(segment),
|
||||
active_count_(active_count),
|
||||
batch_size_(batch_size) {
|
||||
}
|
||||
|
||||
void
|
||||
Eval(EvalCtx& context, VectorPtr& result) override;
|
||||
|
||||
void
|
||||
MoveCursor() override {
|
||||
if (!has_offset_input_) {
|
||||
int64_t real_batch_size =
|
||||
current_pos_ + batch_size_ >= active_count_
|
||||
? active_count_ - current_pos_
|
||||
: batch_size_;
|
||||
current_pos_ += real_batch_size;
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return fmt::format("{}", expr_->ToString());
|
||||
}
|
||||
|
||||
bool
|
||||
IsSource() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::optional<milvus::expr::ColumnInfo>
|
||||
GetColumnInfo() const override {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<const milvus::expr::MatchExpr> expr_;
|
||||
const segcore::SegmentInternalInterface* segment_;
|
||||
int64_t active_count_;
|
||||
int64_t current_pos_{0};
|
||||
int64_t batch_size_;
|
||||
};
|
||||
|
||||
} // namespace exec
|
||||
} // namespace milvus
|
||||
420
internal/core/src/exec/expression/MatchExprTest.cpp
Normal file
420
internal/core/src/exec/expression/MatchExprTest.cpp
Normal file
@ -0,0 +1,420 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
// use this file except in compliance with the License. You may obtain a copy of
|
||||
// the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations under
|
||||
// the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <boost/format.hpp>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
|
||||
#include "common/Schema.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "query/Plan.h"
|
||||
#include "segcore/SegmentGrowingImpl.h"
|
||||
#include "test_utils/DataGen.h"
|
||||
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
|
||||
class MatchExprTest : public ::testing::Test {
|
||||
protected:
|
||||
void
|
||||
SetUp() override {
|
||||
// Create schema with struct array sub-fields
|
||||
schema_ = std::make_shared<Schema>();
|
||||
vec_fid_ = schema_->AddDebugField(
|
||||
"vec", DataType::VECTOR_FLOAT, 4, knowhere::metric::L2);
|
||||
int64_fid_ = schema_->AddDebugField("id", DataType::INT64);
|
||||
schema_->set_primary_field_id(int64_fid_);
|
||||
|
||||
sub_str_fid_ = schema_->AddDebugArrayField(
|
||||
"struct_array[sub_str]", DataType::VARCHAR, false);
|
||||
sub_int_fid_ = schema_->AddDebugArrayField(
|
||||
"struct_array[sub_int]", DataType::INT32, false);
|
||||
|
||||
// Generate test data
|
||||
GenerateTestData();
|
||||
|
||||
// Create and populate segment
|
||||
seg_ = CreateGrowingSegment(schema_, empty_index_meta);
|
||||
seg_->PreInsert(N_);
|
||||
seg_->Insert(
|
||||
0, N_, row_ids_.data(), timestamps_.data(), insert_data_.get());
|
||||
}
|
||||
|
||||
void
|
||||
GenerateTestData() {
|
||||
std::default_random_engine rng(42);
|
||||
std::vector<std::string> str_choices = {"aaa", "bbb", "ccc"};
|
||||
std::uniform_int_distribution<> str_dist(0, 2);
|
||||
std::uniform_int_distribution<> int_dist(50, 150);
|
||||
|
||||
insert_data_ = std::make_unique<InsertRecordProto>();
|
||||
|
||||
// Generate vector field
|
||||
std::vector<float> vec_data(N_ * 4);
|
||||
std::normal_distribution<float> vec_dist(0, 1);
|
||||
for (auto& v : vec_data) {
|
||||
v = vec_dist(rng);
|
||||
}
|
||||
auto vec_array = CreateDataArrayFrom(
|
||||
vec_data.data(), nullptr, N_, schema_->operator[](vec_fid_));
|
||||
insert_data_->mutable_fields_data()->AddAllocated(vec_array.release());
|
||||
|
||||
// Generate id field
|
||||
std::vector<int64_t> id_data(N_);
|
||||
for (size_t i = 0; i < N_; ++i) {
|
||||
id_data[i] = i;
|
||||
}
|
||||
auto id_array = CreateDataArrayFrom(
|
||||
id_data.data(), nullptr, N_, schema_->operator[](int64_fid_));
|
||||
insert_data_->mutable_fields_data()->AddAllocated(id_array.release());
|
||||
|
||||
// Generate struct_array[sub_str]
|
||||
sub_str_data_.resize(N_);
|
||||
for (size_t i = 0; i < N_; ++i) {
|
||||
for (int j = 0; j < array_len_; ++j) {
|
||||
sub_str_data_[i].mutable_string_data()->add_data(
|
||||
str_choices[str_dist(rng)]);
|
||||
}
|
||||
}
|
||||
auto sub_str_array =
|
||||
CreateDataArrayFrom(sub_str_data_.data(),
|
||||
nullptr,
|
||||
N_,
|
||||
schema_->operator[](sub_str_fid_));
|
||||
insert_data_->mutable_fields_data()->AddAllocated(
|
||||
sub_str_array.release());
|
||||
|
||||
// Generate struct_array[sub_int]
|
||||
sub_int_data_.resize(N_);
|
||||
for (size_t i = 0; i < N_; ++i) {
|
||||
for (int j = 0; j < array_len_; ++j) {
|
||||
sub_int_data_[i].mutable_int_data()->add_data(int_dist(rng));
|
||||
}
|
||||
}
|
||||
auto sub_int_array =
|
||||
CreateDataArrayFrom(sub_int_data_.data(),
|
||||
nullptr,
|
||||
N_,
|
||||
schema_->operator[](sub_int_fid_));
|
||||
insert_data_->mutable_fields_data()->AddAllocated(
|
||||
sub_int_array.release());
|
||||
|
||||
insert_data_->set_num_rows(N_);
|
||||
|
||||
// Generate row_ids and timestamps
|
||||
row_ids_.resize(N_);
|
||||
timestamps_.resize(N_);
|
||||
for (size_t i = 0; i < N_; ++i) {
|
||||
row_ids_[i] = i;
|
||||
timestamps_[i] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// Count elements matching: sub_str == "aaa" && sub_int > 100
|
||||
int
|
||||
CountMatchingElements(int64_t row_idx) const {
|
||||
int count = 0;
|
||||
const auto& str_field = sub_str_data_[row_idx];
|
||||
const auto& int_field = sub_int_data_[row_idx];
|
||||
for (int j = 0; j < array_len_; ++j) {
|
||||
bool str_match = (str_field.string_data().data(j) == "aaa");
|
||||
bool int_match = (int_field.int_data().data(j) > 100);
|
||||
if (str_match && int_match) {
|
||||
++count;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
// Create plan with specified match type and count
|
||||
std::string
|
||||
CreatePlanText(const std::string& match_type, int64_t count) {
|
||||
return boost::str(boost::format(R"(vector_anns: <
|
||||
field_id: %1%
|
||||
predicates: <
|
||||
match_expr: <
|
||||
struct_name: "struct_array"
|
||||
match_type: %4%
|
||||
count: %5%
|
||||
predicate: <
|
||||
binary_expr: <
|
||||
op: LogicalAnd
|
||||
left: <
|
||||
unary_range_expr: <
|
||||
column_info: <
|
||||
field_id: %2%
|
||||
data_type: Array
|
||||
element_type: VarChar
|
||||
nested_path: "sub_str"
|
||||
is_element_level: true
|
||||
>
|
||||
op: Equal
|
||||
value: <
|
||||
string_val: "aaa"
|
||||
>
|
||||
>
|
||||
>
|
||||
right: <
|
||||
unary_range_expr: <
|
||||
column_info: <
|
||||
field_id: %3%
|
||||
data_type: Array
|
||||
element_type: Int32
|
||||
nested_path: "sub_int"
|
||||
is_element_level: true
|
||||
>
|
||||
op: GreaterThan
|
||||
value: <
|
||||
int64_val: 100
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
>
|
||||
query_info: <
|
||||
topk: 10
|
||||
round_decimal: 3
|
||||
metric_type: "L2"
|
||||
search_params: "{\"nprobe\": 10}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>)") % vec_fid_.get() %
|
||||
sub_str_fid_.get() % sub_int_fid_.get() % match_type %
|
||||
count);
|
||||
}
|
||||
|
||||
// Execute search and return results
|
||||
std::unique_ptr<SearchResult>
|
||||
ExecuteSearch(const std::string& raw_plan) {
|
||||
proto::plan::PlanNode plan_node;
|
||||
auto ok =
|
||||
google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
|
||||
EXPECT_TRUE(ok) << "Failed to parse plan";
|
||||
|
||||
auto plan = CreateSearchPlanFromPlanNode(schema_, plan_node);
|
||||
EXPECT_NE(plan, nullptr);
|
||||
|
||||
auto num_queries = 1;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 4, 1024);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
|
||||
return seg_->Search(plan.get(), ph_group.get(), 1L << 63);
|
||||
}
|
||||
|
||||
// Verify results based on match type
|
||||
using VerifyFunc = std::function<bool(
|
||||
int match_count, int element_count, int64_t threshold)>;
|
||||
|
||||
void
|
||||
VerifyResults(const SearchResult* result,
|
||||
const std::string& match_type_name,
|
||||
int64_t threshold,
|
||||
VerifyFunc verify_func) {
|
||||
std::cout << "=== " << match_type_name << " Results ===" << std::endl;
|
||||
std::cout << "total_nq: " << result->total_nq_ << std::endl;
|
||||
std::cout << "unity_topK: " << result->unity_topK_ << std::endl;
|
||||
std::cout << "num_results: " << result->seg_offsets_.size()
|
||||
<< std::endl;
|
||||
|
||||
for (int64_t i = 0; i < result->total_nq_; ++i) {
|
||||
std::cout << "Query " << i << ":" << std::endl;
|
||||
for (int64_t k = 0; k < result->unity_topK_; ++k) {
|
||||
int64_t idx = i * result->unity_topK_ + k;
|
||||
auto offset = result->seg_offsets_[idx];
|
||||
auto distance = result->distances_[idx];
|
||||
|
||||
std::cout << " [" << k << "] offset=" << offset
|
||||
<< ", distance=" << distance;
|
||||
|
||||
if (offset >= 0 && offset < static_cast<int64_t>(N_)) {
|
||||
// Print sub_str array
|
||||
std::cout << ", sub_str=[";
|
||||
const auto& str_field = sub_str_data_[offset];
|
||||
for (int j = 0; j < str_field.string_data().data_size();
|
||||
++j) {
|
||||
if (j > 0)
|
||||
std::cout << ",";
|
||||
std::cout << str_field.string_data().data(j);
|
||||
}
|
||||
std::cout << "]";
|
||||
|
||||
// Print sub_int array
|
||||
std::cout << ", sub_int=[";
|
||||
const auto& int_field = sub_int_data_[offset];
|
||||
for (int j = 0; j < int_field.int_data().data_size(); ++j) {
|
||||
if (j > 0)
|
||||
std::cout << ",";
|
||||
std::cout << int_field.int_data().data(j);
|
||||
}
|
||||
std::cout << "]";
|
||||
|
||||
// Print match_count and verify
|
||||
int match_count = CountMatchingElements(offset);
|
||||
bool expected =
|
||||
verify_func(match_count, array_len_, threshold);
|
||||
std::cout << ", match_count=" << match_count;
|
||||
|
||||
EXPECT_TRUE(expected)
|
||||
<< match_type_name << " failed for row " << offset
|
||||
<< ": match_count=" << match_count
|
||||
<< ", element_count=" << array_len_
|
||||
<< ", threshold=" << threshold;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
std::cout << "==============================" << std::endl;
|
||||
}
|
||||
|
||||
// Member variables
|
||||
std::shared_ptr<Schema> schema_;
|
||||
FieldId vec_fid_;
|
||||
FieldId int64_fid_;
|
||||
FieldId sub_str_fid_;
|
||||
FieldId sub_int_fid_;
|
||||
|
||||
std::unique_ptr<InsertRecordProto> insert_data_;
|
||||
std::vector<milvus::proto::schema::ScalarField> sub_str_data_;
|
||||
std::vector<milvus::proto::schema::ScalarField> sub_int_data_;
|
||||
std::vector<idx_t> row_ids_;
|
||||
std::vector<Timestamp> timestamps_;
|
||||
|
||||
SegmentGrowingPtr seg_;
|
||||
|
||||
static constexpr size_t N_ = 1000;
|
||||
static constexpr int array_len_ = 5;
|
||||
};
|
||||
|
||||
TEST_F(MatchExprTest, MatchAny) {
|
||||
auto raw_plan = CreatePlanText("MatchAny", 0);
|
||||
auto result = ExecuteSearch(raw_plan);
|
||||
|
||||
VerifyResults(
|
||||
result.get(),
|
||||
"MatchAny",
|
||||
0,
|
||||
[](int match_count, int /*element_count*/, int64_t /*threshold*/) {
|
||||
// MatchAny: at least one element matches
|
||||
return match_count > 0;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(MatchExprTest, MatchAll) {
|
||||
auto raw_plan = CreatePlanText("MatchAll", 0);
|
||||
auto result = ExecuteSearch(raw_plan);
|
||||
|
||||
VerifyResults(
|
||||
result.get(),
|
||||
"MatchAll",
|
||||
0,
|
||||
[](int match_count, int element_count, int64_t /*threshold*/) {
|
||||
// MatchAll: all elements must match
|
||||
return match_count == element_count;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(MatchExprTest, MatchLeast) {
|
||||
const int64_t threshold = 3;
|
||||
auto raw_plan = CreatePlanText("MatchLeast", threshold);
|
||||
auto result = ExecuteSearch(raw_plan);
|
||||
|
||||
VerifyResults(
|
||||
result.get(),
|
||||
"MatchLeast(3)",
|
||||
threshold,
|
||||
[](int match_count, int /*element_count*/, int64_t threshold) {
|
||||
// MatchLeast: at least N elements match
|
||||
return match_count >= threshold;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(MatchExprTest, MatchMost) {
|
||||
const int64_t threshold = 2;
|
||||
auto raw_plan = CreatePlanText("MatchMost", threshold);
|
||||
auto result = ExecuteSearch(raw_plan);
|
||||
|
||||
VerifyResults(
|
||||
result.get(),
|
||||
"MatchMost(2)",
|
||||
threshold,
|
||||
[](int match_count, int /*element_count*/, int64_t threshold) {
|
||||
// MatchMost: at most N elements match
|
||||
return match_count <= threshold;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(MatchExprTest, MatchExact) {
|
||||
const int64_t threshold = 2;
|
||||
auto raw_plan = CreatePlanText("MatchExact", threshold);
|
||||
auto result = ExecuteSearch(raw_plan);
|
||||
|
||||
VerifyResults(
|
||||
result.get(),
|
||||
"MatchExact(2)",
|
||||
threshold,
|
||||
[](int match_count, int /*element_count*/, int64_t threshold) {
|
||||
// MatchExact: exactly N elements match
|
||||
return match_count == threshold;
|
||||
});
|
||||
}
|
||||
|
||||
// Edge case: MatchLeast with threshold = 1 (equivalent to MatchAny)
|
||||
TEST_F(MatchExprTest, MatchLeastOne) {
|
||||
const int64_t threshold = 1;
|
||||
auto raw_plan = CreatePlanText("MatchLeast", threshold);
|
||||
auto result = ExecuteSearch(raw_plan);
|
||||
|
||||
VerifyResults(
|
||||
result.get(),
|
||||
"MatchLeast(1)",
|
||||
threshold,
|
||||
[](int match_count, int /*element_count*/, int64_t threshold) {
|
||||
return match_count >= threshold;
|
||||
});
|
||||
}
|
||||
|
||||
// Edge case: MatchMost with threshold = 0 (no elements should match)
|
||||
TEST_F(MatchExprTest, MatchMostZero) {
|
||||
const int64_t threshold = 0;
|
||||
auto raw_plan = CreatePlanText("MatchMost", threshold);
|
||||
auto result = ExecuteSearch(raw_plan);
|
||||
|
||||
VerifyResults(
|
||||
result.get(),
|
||||
"MatchMost(0)",
|
||||
threshold,
|
||||
[](int match_count, int /*element_count*/, int64_t threshold) {
|
||||
return match_count <= threshold;
|
||||
});
|
||||
}
|
||||
|
||||
// Edge case: MatchExact with threshold = 0 (no elements should match)
|
||||
TEST_F(MatchExprTest, MatchExactZero) {
|
||||
const int64_t threshold = 0;
|
||||
auto raw_plan = CreatePlanText("MatchExact", threshold);
|
||||
auto result = ExecuteSearch(raw_plan);
|
||||
|
||||
VerifyResults(
|
||||
result.get(),
|
||||
"MatchExact(0)",
|
||||
threshold,
|
||||
[](int match_count, int /*element_count*/, int64_t threshold) {
|
||||
return match_count == threshold;
|
||||
});
|
||||
}
|
||||
@ -31,7 +31,11 @@ PhyNullExpr::Eval(EvalCtx& context, VectorPtr& result) {
|
||||
static_cast<int>(expr_->column_.data_type_));
|
||||
|
||||
auto input = context.get_offset_input();
|
||||
switch (expr_->column_.data_type_) {
|
||||
auto data_type = expr_->column_.data_type_;
|
||||
if (expr_->column_.element_level_) {
|
||||
data_type = expr_->column_.element_type_;
|
||||
}
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
result = ExecVisitorImpl<bool>(input);
|
||||
break;
|
||||
|
||||
@ -88,7 +88,7 @@ PhyElementFilterBitsNode::GetOutput() {
|
||||
|
||||
// Step 3: Convert doc bitset to element offsets
|
||||
FixedVector<int32_t> element_offsets =
|
||||
DocBitsetToElementOffsets(doc_bitset);
|
||||
array_offsets->RowBitsetToElementOffsets(doc_bitset, 0);
|
||||
|
||||
// Step 4: Evaluate element expression
|
||||
auto [expr_result, valid_expr_result] =
|
||||
@ -122,38 +122,6 @@ PhyElementFilterBitsNode::GetOutput() {
|
||||
return std::make_shared<RowVector>(col_res);
|
||||
}
|
||||
|
||||
FixedVector<int32_t>
|
||||
PhyElementFilterBitsNode::DocBitsetToElementOffsets(
|
||||
const TargetBitmapView& doc_bitset) {
|
||||
auto array_offsets = query_context_->get_array_offsets();
|
||||
AssertInfo(array_offsets != nullptr, "Array offsets not available");
|
||||
|
||||
int64_t doc_count = array_offsets->GetRowCount();
|
||||
AssertInfo(doc_bitset.size() == doc_count,
|
||||
"Doc bitset size mismatch: {} vs {}",
|
||||
doc_bitset.size(),
|
||||
doc_count);
|
||||
|
||||
FixedVector<int32_t> element_offsets;
|
||||
element_offsets.reserve(array_offsets->GetTotalElementCount());
|
||||
|
||||
// For each document that passes the filter, get all its element offsets
|
||||
for (int64_t doc_id = 0; doc_id < doc_count; ++doc_id) {
|
||||
if (doc_bitset[doc_id]) {
|
||||
// Get element range for this document
|
||||
auto [first_elem, last_elem] =
|
||||
array_offsets->ElementIDRangeOfRow(doc_id);
|
||||
|
||||
// Add all element IDs for this document
|
||||
for (int64_t elem_id = first_elem; elem_id < last_elem; ++elem_id) {
|
||||
element_offsets.push_back(static_cast<int32_t>(elem_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return element_offsets;
|
||||
}
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
PhyElementFilterBitsNode::EvaluateElementExpression(
|
||||
FixedVector<int32_t>& element_offsets) {
|
||||
@ -162,10 +130,8 @@ PhyElementFilterBitsNode::EvaluateElementExpression(
|
||||
true);
|
||||
tracer::AddEvent(fmt::format("input_elements: {}", element_offsets.size()));
|
||||
|
||||
// Use offset interface by passing element_offsets as third parameter
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context(),
|
||||
element_exprs_.get(),
|
||||
&element_offsets);
|
||||
// Use offset interface by passing element_offsets
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context(), &element_offsets);
|
||||
|
||||
std::vector<VectorPtr> results;
|
||||
element_exprs_->Eval(0, 1, true, eval_ctx, results);
|
||||
|
||||
@ -77,9 +77,6 @@ class PhyElementFilterBitsNode : public Operator {
|
||||
}
|
||||
|
||||
private:
|
||||
FixedVector<int32_t>
|
||||
DocBitsetToElementOffsets(const TargetBitmapView& doc_bitset);
|
||||
|
||||
std::pair<TargetBitmap, TargetBitmap>
|
||||
EvaluateElementExpression(FixedVector<int32_t>& element_offsets);
|
||||
|
||||
|
||||
@ -74,7 +74,7 @@ PhyFilterBitsNode::GetOutput() {
|
||||
std::chrono::high_resolution_clock::time_point scalar_start =
|
||||
std::chrono::high_resolution_clock::now();
|
||||
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context(), exprs_.get());
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context());
|
||||
|
||||
TargetBitmap bitset;
|
||||
TargetBitmap valid_bitset;
|
||||
|
||||
@ -160,7 +160,7 @@ PhyIterativeFilterNode::GetOutput() {
|
||||
TargetBitmap bitset;
|
||||
// get bitset of whole segment first
|
||||
if (!is_native_supported_) {
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context(), exprs_.get());
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context());
|
||||
|
||||
TargetBitmap valid_bitset;
|
||||
while (num_processed_rows_ < need_process_rows_) {
|
||||
@ -225,8 +225,7 @@ PhyIterativeFilterNode::GetOutput() {
|
||||
std::unordered_set<int64_t> unique_doc_ids;
|
||||
|
||||
for (auto& iterator : search_result.vector_iterators_.value()) {
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context(),
|
||||
exprs_.get());
|
||||
EvalCtx eval_ctx(operator_context_->get_exec_context());
|
||||
int topk = 0;
|
||||
while (iterator->HasNext() && topk < unity_topk) {
|
||||
FixedVector<int32_t> offsets;
|
||||
|
||||
@ -111,7 +111,7 @@ PhyRescoresNode::GetOutput() {
|
||||
filters.emplace_back(filter);
|
||||
auto expr_set = std::make_unique<ExprSet>(filters, exec_context);
|
||||
std::vector<VectorPtr> results;
|
||||
EvalCtx eval_ctx(exec_context, expr_set.get());
|
||||
EvalCtx eval_ctx(exec_context);
|
||||
|
||||
const auto& exprs = expr_set->exprs();
|
||||
bool is_native_supported = true;
|
||||
|
||||
@ -104,7 +104,7 @@ PhyVectorSearchNode::GetOutput() {
|
||||
col_input->size());
|
||||
|
||||
auto [element_bitset, valid_element_bitset] =
|
||||
array_offsets->RowBitsetToElementBitset(view, valid_view);
|
||||
array_offsets->RowBitsetToElementBitset(view, valid_view, 0);
|
||||
|
||||
query_context_->set_active_element_count(element_bitset.size());
|
||||
|
||||
|
||||
@ -890,6 +890,49 @@ class JsonContainsExpr : public ITypeFilterExpr {
|
||||
bool same_type_;
|
||||
const std::vector<proto::plan::GenericValue> vals_;
|
||||
};
|
||||
|
||||
// MatchType mirrors the protobuf MatchType enum
|
||||
using MatchType = proto::plan::MatchType;
|
||||
|
||||
class MatchExpr : public ITypeFilterExpr {
|
||||
public:
|
||||
MatchExpr(const std::string& struct_name,
|
||||
MatchType match_type,
|
||||
int64_t count,
|
||||
const TypedExprPtr& predicate)
|
||||
: struct_name_(struct_name), match_type_(match_type), count_(count) {
|
||||
inputs_.push_back(predicate);
|
||||
}
|
||||
|
||||
std::string
|
||||
ToString() const override {
|
||||
return fmt::format("MatchExpr(struct_name={}, match_type={}, count={})",
|
||||
struct_name_,
|
||||
proto::plan::MatchType_Name(match_type_),
|
||||
count_);
|
||||
}
|
||||
|
||||
const std::string&
|
||||
get_struct_name() const {
|
||||
return struct_name_;
|
||||
}
|
||||
|
||||
MatchType
|
||||
get_match_type() const {
|
||||
return match_type_;
|
||||
}
|
||||
|
||||
int64_t
|
||||
get_count() const {
|
||||
return count_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string struct_name_;
|
||||
MatchType match_type_;
|
||||
int64_t count_; // Used for MatchLeast/MatchMost/MatchExact
|
||||
};
|
||||
|
||||
} // namespace expr
|
||||
} // namespace milvus
|
||||
|
||||
|
||||
@ -445,7 +445,7 @@ ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
|
||||
if (column_info.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() ==
|
||||
static_cast<DataType>(column_info.data_type()));
|
||||
static_cast<DataType>(column_info.element_type()));
|
||||
} else {
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
}
|
||||
@ -470,7 +470,7 @@ ProtoParser::ParseNullExprs(const proto::plan::NullExpr& expr_pb) {
|
||||
if (column_info.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() ==
|
||||
static_cast<DataType>(column_info.data_type()));
|
||||
static_cast<DataType>(column_info.element_type()));
|
||||
} else {
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
}
|
||||
@ -488,7 +488,7 @@ ProtoParser::ParseBinaryRangeExprs(
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
@ -510,7 +510,7 @@ ProtoParser::ParseTimestamptzArithCompareExprs(
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
@ -533,6 +533,16 @@ ProtoParser::ParseElementFilterExprs(
|
||||
"ElementFilterExpr must be handled at PlanNode level.");
|
||||
}
|
||||
|
||||
expr::TypedExprPtr
|
||||
ProtoParser::ParseMatchExprs(const proto::plan::MatchExpr& expr_pb) {
|
||||
auto struct_name = expr_pb.struct_name();
|
||||
auto match_type = expr_pb.match_type();
|
||||
auto count = expr_pb.count();
|
||||
auto predicate = this->ParseExprs(expr_pb.predicate());
|
||||
return std::make_shared<expr::MatchExpr>(
|
||||
struct_name, match_type, count, predicate);
|
||||
}
|
||||
|
||||
expr::TypedExprPtr
|
||||
ProtoParser::ParseCallExprs(const proto::plan::CallExpr& expr_pb) {
|
||||
std::vector<expr::TypedExprPtr> parameters;
|
||||
@ -566,7 +576,7 @@ ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
|
||||
if (left_column_info.is_element_level()) {
|
||||
Assert(left_data_type == DataType::ARRAY);
|
||||
Assert(left_field.get_element_type() ==
|
||||
static_cast<DataType>(left_column_info.data_type()));
|
||||
static_cast<DataType>(left_column_info.element_type()));
|
||||
} else {
|
||||
Assert(left_data_type ==
|
||||
static_cast<DataType>(left_column_info.data_type()));
|
||||
@ -580,7 +590,7 @@ ProtoParser::ParseCompareExprs(const proto::plan::CompareExpr& expr_pb) {
|
||||
if (right_column_info.is_element_level()) {
|
||||
Assert(right_data_type == DataType::ARRAY);
|
||||
Assert(right_field.get_element_type() ==
|
||||
static_cast<DataType>(right_column_info.data_type()));
|
||||
static_cast<DataType>(right_column_info.element_type()));
|
||||
} else {
|
||||
Assert(right_data_type ==
|
||||
static_cast<DataType>(right_column_info.data_type()));
|
||||
@ -602,7 +612,7 @@ ProtoParser::ParseTermExprs(const proto::plan::TermExpr& expr_pb) {
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
@ -641,7 +651,7 @@ ProtoParser::ParseBinaryArithOpEvalRangeExprs(
|
||||
if (column_info.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() ==
|
||||
static_cast<DataType>(column_info.data_type()));
|
||||
static_cast<DataType>(column_info.element_type()));
|
||||
} else {
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
}
|
||||
@ -663,7 +673,7 @@ ProtoParser::ParseExistExprs(const proto::plan::ExistsExpr& expr_pb) {
|
||||
if (column_info.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() ==
|
||||
static_cast<DataType>(column_info.data_type()));
|
||||
static_cast<DataType>(column_info.element_type()));
|
||||
} else {
|
||||
Assert(data_type == static_cast<DataType>(column_info.data_type()));
|
||||
}
|
||||
@ -680,7 +690,7 @@ ProtoParser::ParseJsonContainsExprs(
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
@ -715,7 +725,7 @@ ProtoParser::ParseGISFunctionFilterExprs(
|
||||
|
||||
if (columnInfo.is_element_level()) {
|
||||
Assert(data_type == DataType::ARRAY);
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.data_type());
|
||||
Assert(field.get_element_type() == (DataType)columnInfo.element_type());
|
||||
} else {
|
||||
Assert(data_type == (DataType)columnInfo.data_type());
|
||||
}
|
||||
@ -809,6 +819,10 @@ ProtoParser::ParseExprs(const proto::plan::Expr& expr_pb,
|
||||
"ElementFilterExpr should be handled at PlanNode level, "
|
||||
"not in ParseExprs");
|
||||
}
|
||||
case ppe::kMatchExpr: {
|
||||
result = ParseMatchExprs(expr_pb.match_expr());
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
std::string s;
|
||||
google::protobuf::TextFormat::PrintToString(expr_pb, &s);
|
||||
|
||||
@ -109,6 +109,9 @@ class ProtoParser {
|
||||
expr::TypedExprPtr
|
||||
ParseElementFilterExprs(const proto::plan::ElementFilterExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseMatchExprs(const proto::plan::MatchExpr& expr_pb);
|
||||
|
||||
expr::TypedExprPtr
|
||||
ParseValueExprs(const proto::plan::ValueExpr& expr_pb);
|
||||
|
||||
|
||||
@ -166,8 +166,7 @@ gen_filter_res(milvus::plan::PlanNode* plan_node,
|
||||
auto exprs_ =
|
||||
std::make_unique<milvus::exec::ExprSet>(filters, exec_context.get());
|
||||
std::vector<VectorPtr> results_;
|
||||
milvus::exec::EvalCtx eval_ctx(exec_context.get(), exprs_.get());
|
||||
eval_ctx.set_offset_input(offsets);
|
||||
milvus::exec::EvalCtx eval_ctx(exec_context.get(), offsets);
|
||||
exprs_->Eval(0, 1, true, eval_ctx, results_);
|
||||
|
||||
auto col_vec = std::dynamic_pointer_cast<milvus::ColumnVector>(results_[0]);
|
||||
|
||||
@ -20,6 +20,11 @@ expr:
|
||||
| PHRASEMATCH'('Identifier',' StringLiteral (',' expr)? ')' # PhraseMatch
|
||||
| RANDOMSAMPLE'(' expr ')' # RandomSample
|
||||
| ElementFilter'('Identifier',' expr')' # ElementFilter
|
||||
| MATCH_ALL'(' Identifier ',' expr ')' # MatchAll
|
||||
| MATCH_ANY'(' Identifier ',' expr ')' # MatchAny
|
||||
| MATCH_LEAST'(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchLeast
|
||||
| MATCH_MOST'(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchMost
|
||||
| MATCH_EXACT'(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchExact
|
||||
| expr POW expr # Power
|
||||
| op = (ADD | SUB | BNOT | NOT) expr # Unary
|
||||
// | '(' typeName ')' expr # Cast
|
||||
@ -80,9 +85,15 @@ EXISTS: 'exists' | 'EXISTS';
|
||||
TEXTMATCH: 'text_match'|'TEXT_MATCH';
|
||||
PHRASEMATCH: 'phrase_match'|'PHRASE_MATCH';
|
||||
RANDOMSAMPLE: 'random_sample' | 'RANDOM_SAMPLE';
|
||||
MATCH_ALL: 'match_all' | 'MATCH_ALL';
|
||||
MATCH_ANY: 'match_any' | 'MATCH_ANY';
|
||||
MATCH_LEAST: 'match_least' | 'MATCH_LEAST';
|
||||
MATCH_MOST: 'match_most' | 'MATCH_MOST';
|
||||
MATCH_EXACT: 'match_exact' | 'MATCH_EXACT';
|
||||
INTERVAL: 'interval' | 'INTERVAL';
|
||||
ISO: 'iso' | 'ISO';
|
||||
MINIMUM_SHOULD_MATCH: 'minimum_should_match' | 'MINIMUM_SHOULD_MATCH';
|
||||
THRESHOLD: 'threshold' | 'THRESHOLD';
|
||||
ASSIGN: '=';
|
||||
|
||||
ADD: '+';
|
||||
|
||||
@ -64,7 +64,8 @@ func FillTermExpressionValue(expr *planpb.TermExpr, templateValues map[string]*p
|
||||
}
|
||||
dataType := expr.GetColumnInfo().GetDataType()
|
||||
if typeutil.IsArrayType(dataType) {
|
||||
if len(expr.GetColumnInfo().GetNestedPath()) != 0 {
|
||||
// Use element type if accessing array element
|
||||
if len(expr.GetColumnInfo().GetNestedPath()) != 0 || expr.GetColumnInfo().GetIsElementLevel() {
|
||||
dataType = expr.GetColumnInfo().GetElementType()
|
||||
}
|
||||
}
|
||||
@ -91,7 +92,8 @@ func FillUnaryRangeExpressionValue(expr *planpb.UnaryRangeExpr, templateValues m
|
||||
|
||||
dataType := expr.GetColumnInfo().GetDataType()
|
||||
if typeutil.IsArrayType(dataType) {
|
||||
if len(expr.GetColumnInfo().GetNestedPath()) != 0 {
|
||||
// Use element type if accessing array element
|
||||
if len(expr.GetColumnInfo().GetNestedPath()) != 0 || expr.GetColumnInfo().GetIsElementLevel() {
|
||||
dataType = expr.GetColumnInfo().GetElementType()
|
||||
}
|
||||
}
|
||||
@ -107,7 +109,8 @@ func FillUnaryRangeExpressionValue(expr *planpb.UnaryRangeExpr, templateValues m
|
||||
func FillBinaryRangeExpressionValue(expr *planpb.BinaryRangeExpr, templateValues map[string]*planpb.GenericValue) error {
|
||||
var ok bool
|
||||
dataType := expr.GetColumnInfo().GetDataType()
|
||||
if typeutil.IsArrayType(dataType) && len(expr.GetColumnInfo().GetNestedPath()) != 0 {
|
||||
// Use element type if accessing array element
|
||||
if typeutil.IsArrayType(dataType) && (len(expr.GetColumnInfo().GetNestedPath()) != 0 || expr.GetColumnInfo().GetIsElementLevel()) {
|
||||
dataType = expr.GetColumnInfo().GetElementType()
|
||||
}
|
||||
lowerValue := expr.GetLowerValue()
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -16,56 +16,62 @@ EXISTS=15
|
||||
TEXTMATCH=16
|
||||
PHRASEMATCH=17
|
||||
RANDOMSAMPLE=18
|
||||
INTERVAL=19
|
||||
ISO=20
|
||||
MINIMUM_SHOULD_MATCH=21
|
||||
ASSIGN=22
|
||||
ADD=23
|
||||
SUB=24
|
||||
MUL=25
|
||||
DIV=26
|
||||
MOD=27
|
||||
POW=28
|
||||
SHL=29
|
||||
SHR=30
|
||||
BAND=31
|
||||
BOR=32
|
||||
BXOR=33
|
||||
AND=34
|
||||
OR=35
|
||||
ISNULL=36
|
||||
ISNOTNULL=37
|
||||
BNOT=38
|
||||
NOT=39
|
||||
IN=40
|
||||
EmptyArray=41
|
||||
JSONContains=42
|
||||
JSONContainsAll=43
|
||||
JSONContainsAny=44
|
||||
ArrayContains=45
|
||||
ArrayContainsAll=46
|
||||
ArrayContainsAny=47
|
||||
ArrayLength=48
|
||||
ElementFilter=49
|
||||
STEuqals=50
|
||||
STTouches=51
|
||||
STOverlaps=52
|
||||
STCrosses=53
|
||||
STContains=54
|
||||
STIntersects=55
|
||||
STWithin=56
|
||||
STDWithin=57
|
||||
STIsValid=58
|
||||
BooleanConstant=59
|
||||
IntegerConstant=60
|
||||
FloatingConstant=61
|
||||
Identifier=62
|
||||
Meta=63
|
||||
StringLiteral=64
|
||||
JSONIdentifier=65
|
||||
StructSubFieldIdentifier=66
|
||||
Whitespace=67
|
||||
Newline=68
|
||||
MATCH_ALL=19
|
||||
MATCH_ANY=20
|
||||
MATCH_LEAST=21
|
||||
MATCH_MOST=22
|
||||
MATCH_EXACT=23
|
||||
INTERVAL=24
|
||||
ISO=25
|
||||
MINIMUM_SHOULD_MATCH=26
|
||||
THRESHOLD=27
|
||||
ASSIGN=28
|
||||
ADD=29
|
||||
SUB=30
|
||||
MUL=31
|
||||
DIV=32
|
||||
MOD=33
|
||||
POW=34
|
||||
SHL=35
|
||||
SHR=36
|
||||
BAND=37
|
||||
BOR=38
|
||||
BXOR=39
|
||||
AND=40
|
||||
OR=41
|
||||
ISNULL=42
|
||||
ISNOTNULL=43
|
||||
BNOT=44
|
||||
NOT=45
|
||||
IN=46
|
||||
EmptyArray=47
|
||||
JSONContains=48
|
||||
JSONContainsAll=49
|
||||
JSONContainsAny=50
|
||||
ArrayContains=51
|
||||
ArrayContainsAll=52
|
||||
ArrayContainsAny=53
|
||||
ArrayLength=54
|
||||
ElementFilter=55
|
||||
STEuqals=56
|
||||
STTouches=57
|
||||
STOverlaps=58
|
||||
STCrosses=59
|
||||
STContains=60
|
||||
STIntersects=61
|
||||
STWithin=62
|
||||
STDWithin=63
|
||||
STIsValid=64
|
||||
BooleanConstant=65
|
||||
IntegerConstant=66
|
||||
FloatingConstant=67
|
||||
Identifier=68
|
||||
Meta=69
|
||||
StringLiteral=70
|
||||
JSONIdentifier=71
|
||||
StructSubFieldIdentifier=72
|
||||
Whitespace=73
|
||||
Newline=74
|
||||
'('=1
|
||||
')'=2
|
||||
'['=3
|
||||
@ -79,17 +85,17 @@ Newline=68
|
||||
'>='=11
|
||||
'=='=12
|
||||
'!='=13
|
||||
'='=22
|
||||
'+'=23
|
||||
'-'=24
|
||||
'*'=25
|
||||
'/'=26
|
||||
'%'=27
|
||||
'**'=28
|
||||
'<<'=29
|
||||
'>>'=30
|
||||
'&'=31
|
||||
'|'=32
|
||||
'^'=33
|
||||
'~'=38
|
||||
'$meta'=63
|
||||
'='=28
|
||||
'+'=29
|
||||
'-'=30
|
||||
'*'=31
|
||||
'/'=32
|
||||
'%'=33
|
||||
'**'=34
|
||||
'<<'=35
|
||||
'>>'=36
|
||||
'&'=37
|
||||
'|'=38
|
||||
'^'=39
|
||||
'~'=44
|
||||
'$meta'=69
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -16,56 +16,62 @@ EXISTS=15
|
||||
TEXTMATCH=16
|
||||
PHRASEMATCH=17
|
||||
RANDOMSAMPLE=18
|
||||
INTERVAL=19
|
||||
ISO=20
|
||||
MINIMUM_SHOULD_MATCH=21
|
||||
ASSIGN=22
|
||||
ADD=23
|
||||
SUB=24
|
||||
MUL=25
|
||||
DIV=26
|
||||
MOD=27
|
||||
POW=28
|
||||
SHL=29
|
||||
SHR=30
|
||||
BAND=31
|
||||
BOR=32
|
||||
BXOR=33
|
||||
AND=34
|
||||
OR=35
|
||||
ISNULL=36
|
||||
ISNOTNULL=37
|
||||
BNOT=38
|
||||
NOT=39
|
||||
IN=40
|
||||
EmptyArray=41
|
||||
JSONContains=42
|
||||
JSONContainsAll=43
|
||||
JSONContainsAny=44
|
||||
ArrayContains=45
|
||||
ArrayContainsAll=46
|
||||
ArrayContainsAny=47
|
||||
ArrayLength=48
|
||||
ElementFilter=49
|
||||
STEuqals=50
|
||||
STTouches=51
|
||||
STOverlaps=52
|
||||
STCrosses=53
|
||||
STContains=54
|
||||
STIntersects=55
|
||||
STWithin=56
|
||||
STDWithin=57
|
||||
STIsValid=58
|
||||
BooleanConstant=59
|
||||
IntegerConstant=60
|
||||
FloatingConstant=61
|
||||
Identifier=62
|
||||
Meta=63
|
||||
StringLiteral=64
|
||||
JSONIdentifier=65
|
||||
StructSubFieldIdentifier=66
|
||||
Whitespace=67
|
||||
Newline=68
|
||||
MATCH_ALL=19
|
||||
MATCH_ANY=20
|
||||
MATCH_LEAST=21
|
||||
MATCH_MOST=22
|
||||
MATCH_EXACT=23
|
||||
INTERVAL=24
|
||||
ISO=25
|
||||
MINIMUM_SHOULD_MATCH=26
|
||||
THRESHOLD=27
|
||||
ASSIGN=28
|
||||
ADD=29
|
||||
SUB=30
|
||||
MUL=31
|
||||
DIV=32
|
||||
MOD=33
|
||||
POW=34
|
||||
SHL=35
|
||||
SHR=36
|
||||
BAND=37
|
||||
BOR=38
|
||||
BXOR=39
|
||||
AND=40
|
||||
OR=41
|
||||
ISNULL=42
|
||||
ISNOTNULL=43
|
||||
BNOT=44
|
||||
NOT=45
|
||||
IN=46
|
||||
EmptyArray=47
|
||||
JSONContains=48
|
||||
JSONContainsAll=49
|
||||
JSONContainsAny=50
|
||||
ArrayContains=51
|
||||
ArrayContainsAll=52
|
||||
ArrayContainsAny=53
|
||||
ArrayLength=54
|
||||
ElementFilter=55
|
||||
STEuqals=56
|
||||
STTouches=57
|
||||
STOverlaps=58
|
||||
STCrosses=59
|
||||
STContains=60
|
||||
STIntersects=61
|
||||
STWithin=62
|
||||
STDWithin=63
|
||||
STIsValid=64
|
||||
BooleanConstant=65
|
||||
IntegerConstant=66
|
||||
FloatingConstant=67
|
||||
Identifier=68
|
||||
Meta=69
|
||||
StringLiteral=70
|
||||
JSONIdentifier=71
|
||||
StructSubFieldIdentifier=72
|
||||
Whitespace=73
|
||||
Newline=74
|
||||
'('=1
|
||||
')'=2
|
||||
'['=3
|
||||
@ -79,17 +85,17 @@ Newline=68
|
||||
'>='=11
|
||||
'=='=12
|
||||
'!='=13
|
||||
'='=22
|
||||
'+'=23
|
||||
'-'=24
|
||||
'*'=25
|
||||
'/'=26
|
||||
'%'=27
|
||||
'**'=28
|
||||
'<<'=29
|
||||
'>>'=30
|
||||
'&'=31
|
||||
'|'=32
|
||||
'^'=33
|
||||
'~'=38
|
||||
'$meta'=63
|
||||
'='=28
|
||||
'+'=29
|
||||
'-'=30
|
||||
'*'=31
|
||||
'/'=32
|
||||
'%'=33
|
||||
'**'=34
|
||||
'<<'=35
|
||||
'>>'=36
|
||||
'&'=37
|
||||
'|'=38
|
||||
'^'=39
|
||||
'~'=44
|
||||
'$meta'=69
|
||||
|
||||
@ -11,6 +11,10 @@ func (v *BasePlanVisitor) VisitString(ctx *StringContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitMatchAny(ctx *MatchAnyContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitFloating(ctx *FloatingContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -59,6 +63,10 @@ func (v *BasePlanVisitor) VisitPhraseMatch(ctx *PhraseMatchContext) interface{}
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitMatchLeast(ctx *MatchLeastContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitArrayLength(ctx *ArrayLengthContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -83,6 +91,10 @@ func (v *BasePlanVisitor) VisitRange(ctx *RangeContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitMatchAll(ctx *MatchAllContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitSTIsValid(ctx *STIsValidContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -103,6 +115,10 @@ func (v *BasePlanVisitor) VisitSTOverlaps(ctx *STOverlapsContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitMatchMost(ctx *MatchMostContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
@ -115,6 +131,10 @@ func (v *BasePlanVisitor) VisitParens(ctx *ParensContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitMatchExact(ctx *MatchExactContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
func (v *BasePlanVisitor) VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{} {
|
||||
return v.VisitChildren(ctx)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -10,6 +10,9 @@ type PlanVisitor interface {
|
||||
// Visit a parse tree produced by PlanParser#String.
|
||||
VisitString(ctx *StringContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#MatchAny.
|
||||
VisitMatchAny(ctx *MatchAnyContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#Floating.
|
||||
VisitFloating(ctx *FloatingContext) interface{}
|
||||
|
||||
@ -46,6 +49,9 @@ type PlanVisitor interface {
|
||||
// Visit a parse tree produced by PlanParser#PhraseMatch.
|
||||
VisitPhraseMatch(ctx *PhraseMatchContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#MatchLeast.
|
||||
VisitMatchLeast(ctx *MatchLeastContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#ArrayLength.
|
||||
VisitArrayLength(ctx *ArrayLengthContext) interface{}
|
||||
|
||||
@ -64,6 +70,9 @@ type PlanVisitor interface {
|
||||
// Visit a parse tree produced by PlanParser#Range.
|
||||
VisitRange(ctx *RangeContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#MatchAll.
|
||||
VisitMatchAll(ctx *MatchAllContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#STIsValid.
|
||||
VisitSTIsValid(ctx *STIsValidContext) interface{}
|
||||
|
||||
@ -79,6 +88,9 @@ type PlanVisitor interface {
|
||||
// Visit a parse tree produced by PlanParser#STOverlaps.
|
||||
VisitSTOverlaps(ctx *STOverlapsContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#MatchMost.
|
||||
VisitMatchMost(ctx *MatchMostContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#JSONIdentifier.
|
||||
VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{}
|
||||
|
||||
@ -88,6 +100,9 @@ type PlanVisitor interface {
|
||||
// Visit a parse tree produced by PlanParser#Parens.
|
||||
VisitParens(ctx *ParensContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#MatchExact.
|
||||
VisitMatchExact(ctx *MatchExactContext) interface{}
|
||||
|
||||
// Visit a parse tree produced by PlanParser#JSONContainsAll.
|
||||
VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{}
|
||||
|
||||
|
||||
@ -192,7 +192,7 @@ func (v *ParserVisitor) VisitString(ctx *parser.StringContext) interface{} {
|
||||
}
|
||||
|
||||
func checkDirectComparisonBinaryField(columnInfo *planpb.ColumnInfo) error {
|
||||
if typeutil.IsArrayType(columnInfo.GetDataType()) && len(columnInfo.GetNestedPath()) == 0 {
|
||||
if typeutil.IsArrayType(columnInfo.GetDataType()) && len(columnInfo.GetNestedPath()) == 0 && !columnInfo.GetIsElementLevel() {
|
||||
return errors.New("can not comparisons array fields directly")
|
||||
}
|
||||
return nil
|
||||
@ -664,6 +664,10 @@ func isElementFilterExpr(expr *ExprWithType) bool {
|
||||
return expr.expr.GetElementFilterExpr() != nil
|
||||
}
|
||||
|
||||
func isMatchExpr(expr *ExprWithType) bool {
|
||||
return expr.expr.GetMatchExpr() != nil
|
||||
}
|
||||
|
||||
const EPSILON = 1e-10
|
||||
|
||||
func (v *ParserVisitor) VisitRandomSample(ctx *parser.RandomSampleContext) interface{} {
|
||||
@ -715,7 +719,10 @@ func (v *ParserVisitor) VisitTerm(ctx *parser.TermContext) interface{} {
|
||||
}
|
||||
|
||||
dataType := columnInfo.GetDataType()
|
||||
if typeutil.IsArrayType(dataType) && len(columnInfo.GetNestedPath()) != 0 {
|
||||
// Use element type for IN operation in two cases:
|
||||
// 1. Array with nested path (e.g., arr[0] IN [1, 2, 3])
|
||||
// 2. Array with element level flag (e.g., $[intField] IN [1, 2] in MATCH_ALL/ElementFilter)
|
||||
if typeutil.IsArrayType(dataType) && (len(columnInfo.GetNestedPath()) != 0 || columnInfo.GetIsElementLevel()) {
|
||||
dataType = columnInfo.GetElementType()
|
||||
}
|
||||
|
||||
@ -2298,9 +2305,9 @@ func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) i
|
||||
// Remove "$[" prefix and "]" suffix
|
||||
fieldName := tokenText[2 : len(tokenText)-1]
|
||||
|
||||
// Check if we're inside an ElementFilter context
|
||||
// Check if we're inside an ElementFilter or MATCH_* context
|
||||
if v.currentStructArrayField == "" {
|
||||
return fmt.Errorf("$[%s] syntax can only be used inside ElementFilter", fieldName)
|
||||
return fmt.Errorf("$[%s] syntax can only be used inside ElementFilter or MATCH_*", fieldName)
|
||||
}
|
||||
|
||||
// Construct full field name for struct array field
|
||||
@ -2311,7 +2318,7 @@ func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) i
|
||||
return fmt.Errorf("array field not found: %s, error: %s", fullFieldName, err)
|
||||
}
|
||||
|
||||
// In element-level context, data_type should be the element type
|
||||
// In element-level context, use Array as storage type, element type for operations
|
||||
elementType := field.GetElementType()
|
||||
|
||||
return &ExprWithType{
|
||||
@ -2320,12 +2327,12 @@ func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) i
|
||||
ColumnExpr: &planpb.ColumnExpr{
|
||||
Info: &planpb.ColumnInfo{
|
||||
FieldId: field.FieldID,
|
||||
DataType: elementType, // Use element type, not storage type
|
||||
DataType: schemapb.DataType_Array, // Storage type is Array
|
||||
IsPrimaryKey: field.IsPrimaryKey,
|
||||
IsAutoID: field.AutoID,
|
||||
IsPartitionKey: field.IsPartitionKey,
|
||||
IsClusteringKey: field.IsClusteringKey,
|
||||
ElementType: elementType,
|
||||
ElementType: elementType, // Element type for operations
|
||||
Nullable: field.GetNullable(),
|
||||
IsElementLevel: true, // Mark as element-level access
|
||||
},
|
||||
@ -2336,3 +2343,113 @@ func (v *ParserVisitor) VisitStructSubField(ctx *parser.StructSubFieldContext) i
|
||||
nodeDependent: true,
|
||||
}
|
||||
}
|
||||
|
||||
// parseMatchExpr is a helper function for parsing match expressions
|
||||
// matchType: the type of match operation (MatchAll, MatchAny, MatchLeast, MatchMost)
|
||||
// count: for MatchLeast/MatchMost, the count parameter (N); for MatchAll/MatchAny, this is ignored (0)
|
||||
func (v *ParserVisitor) parseMatchExpr(structArrayFieldName string, exprCtx parser.IExprContext, matchType planpb.MatchType, count int64, funcName string) interface{} {
|
||||
// Check for nested match expression - not allowed
|
||||
if v.currentStructArrayField != "" {
|
||||
return fmt.Errorf("nested %s is not supported, already inside match expression for field: %s", funcName, v.currentStructArrayField)
|
||||
}
|
||||
|
||||
// Set current context for element expression parsing
|
||||
v.currentStructArrayField = structArrayFieldName
|
||||
defer func() { v.currentStructArrayField = "" }()
|
||||
|
||||
// Parse the predicate expression
|
||||
predicate := exprCtx.Accept(v)
|
||||
if err := getError(predicate); err != nil {
|
||||
return fmt.Errorf("cannot parse predicate expression: %s, error: %s", exprCtx.GetText(), err)
|
||||
}
|
||||
|
||||
predicateExpr := getExpr(predicate)
|
||||
if predicateExpr == nil {
|
||||
return fmt.Errorf("invalid predicate expression in %s: %s", funcName, exprCtx.GetText())
|
||||
}
|
||||
|
||||
// Build MatchExpr proto
|
||||
return &ExprWithType{
|
||||
expr: &planpb.Expr{
|
||||
Expr: &planpb.Expr_MatchExpr{
|
||||
MatchExpr: &planpb.MatchExpr{
|
||||
StructName: structArrayFieldName,
|
||||
Predicate: predicateExpr.expr,
|
||||
MatchType: matchType,
|
||||
Count: count,
|
||||
},
|
||||
},
|
||||
},
|
||||
dataType: schemapb.DataType_Bool,
|
||||
}
|
||||
}
|
||||
|
||||
// VisitMatchAll handles MATCH_ALL expressions
|
||||
// Syntax: MATCH_ALL(structArrayField, $[intField] == 1 && $[strField] == "aaa")
|
||||
// All elements must match the predicate
|
||||
func (v *ParserVisitor) VisitMatchAll(ctx *parser.MatchAllContext) interface{} {
|
||||
structArrayFieldName := ctx.Identifier().GetText()
|
||||
return v.parseMatchExpr(structArrayFieldName, ctx.Expr(), planpb.MatchType_MatchAll, 0, "MATCH_ALL")
|
||||
}
|
||||
|
||||
// VisitMatchAny handles MATCH_ANY expressions
|
||||
// Syntax: MATCH_ANY(structArrayField, $[intField] == 1 && $[strField] == "aaa")
|
||||
// At least one element must match the predicate
|
||||
func (v *ParserVisitor) VisitMatchAny(ctx *parser.MatchAnyContext) interface{} {
|
||||
structArrayFieldName := ctx.Identifier().GetText()
|
||||
return v.parseMatchExpr(structArrayFieldName, ctx.Expr(), planpb.MatchType_MatchAny, 0, "MATCH_ANY")
|
||||
}
|
||||
|
||||
// VisitMatchLeast handles MATCH_LEAST expressions
|
||||
// Syntax: MATCH_LEAST(structArrayField, $[intField] == 1 && $[strField] == "aaa", threshold=N)
|
||||
// At least N elements must match the predicate
|
||||
func (v *ParserVisitor) VisitMatchLeast(ctx *parser.MatchLeastContext) interface{} {
|
||||
structArrayFieldName := ctx.Identifier().GetText()
|
||||
|
||||
countStr := ctx.IntegerConstant().GetText()
|
||||
count, err := strconv.ParseInt(countStr, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid count in MATCH_LEAST: %s", countStr)
|
||||
}
|
||||
if count <= 0 {
|
||||
return fmt.Errorf("count in MATCH_LEAST must be positive, got: %d", count)
|
||||
}
|
||||
|
||||
return v.parseMatchExpr(structArrayFieldName, ctx.Expr(), planpb.MatchType_MatchLeast, count, "MATCH_LEAST")
|
||||
}
|
||||
|
||||
// VisitMatchMost handles MATCH_MOST expressions
|
||||
// Syntax: MATCH_MOST(structArrayField, $[intField] == 1 && $[strField] == "aaa", threshold=N)
|
||||
// At most N elements must match the predicate
|
||||
func (v *ParserVisitor) VisitMatchMost(ctx *parser.MatchMostContext) interface{} {
|
||||
structArrayFieldName := ctx.Identifier().GetText()
|
||||
|
||||
countStr := ctx.IntegerConstant().GetText()
|
||||
count, err := strconv.ParseInt(countStr, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid count in MATCH_MOST: %s", countStr)
|
||||
}
|
||||
if count < 0 {
|
||||
return fmt.Errorf("count in MATCH_MOST cannot be negative, got: %d", count)
|
||||
}
|
||||
|
||||
return v.parseMatchExpr(structArrayFieldName, ctx.Expr(), planpb.MatchType_MatchMost, count, "MATCH_MOST")
|
||||
}
|
||||
|
||||
// VisitMatchExact handles MATCH_EXACT expressions
|
||||
// Syntax: MATCH_EXACT(structArrayField, $[intField] == 1 && $[strField] == "aaa", threshold=N)
|
||||
// Exactly N elements must match the predicate
|
||||
func (v *ParserVisitor) VisitMatchExact(ctx *parser.MatchExactContext) interface{} {
|
||||
structArrayFieldName := ctx.Identifier().GetText()
|
||||
|
||||
countStr := ctx.IntegerConstant().GetText()
|
||||
count, err := strconv.ParseInt(countStr, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid count in MATCH_EXACT: %s", countStr)
|
||||
}
|
||||
if count < 0 {
|
||||
return fmt.Errorf("count in MATCH_EXACT cannot be negative, got: %d", count)
|
||||
}
|
||||
|
||||
return v.parseMatchExpr(structArrayFieldName, ctx.Expr(), planpb.MatchType_MatchExact, count, "MATCH_EXACT")
|
||||
}
|
||||
|
||||
@ -2557,3 +2557,150 @@ func TestExpr_ElementFilter(t *testing.T) {
|
||||
assertInvalidExpr(t, helper, expr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpr_Match(t *testing.T) {
|
||||
schema := newTestSchema(true)
|
||||
helper, err := typeutil.CreateSchemaHelper(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Valid MATCH_ALL expressions
|
||||
validExprs := []string{
|
||||
// MATCH_ALL: all elements must match
|
||||
`MATCH_ALL(struct_array, $[sub_int] > 1)`,
|
||||
`MATCH_ALL(struct_array, $[sub_int] == 100)`,
|
||||
`MATCH_ALL(struct_array, $[sub_str] == "aaa")`,
|
||||
`MATCH_ALL(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100)`,
|
||||
`MATCH_ALL(struct_array, $[sub_str] != "" || $[sub_int] >= 0)`,
|
||||
|
||||
// MATCH_ANY: at least one element must match
|
||||
`MATCH_ANY(struct_array, $[sub_int] > 1)`,
|
||||
`MATCH_ANY(struct_array, $[sub_int] == 100)`,
|
||||
`MATCH_ANY(struct_array, $[sub_str] == "aaa")`,
|
||||
`MATCH_ANY(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100)`,
|
||||
|
||||
// MATCH_LEAST: at least N elements must match
|
||||
`MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=3)`,
|
||||
`MATCH_LEAST(struct_array, $[sub_str] == "aaa", threshold=1)`,
|
||||
`MATCH_LEAST(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100, threshold=2)`,
|
||||
|
||||
// MATCH_MOST: at most N elements must match
|
||||
`MATCH_MOST(struct_array, $[sub_int] > 1, threshold=3)`,
|
||||
`MATCH_MOST(struct_array, $[sub_str] == "aaa", threshold=0)`,
|
||||
`MATCH_MOST(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100, threshold=5)`,
|
||||
|
||||
// MATCH_EXACT: exactly N elements must match
|
||||
`MATCH_EXACT(struct_array, $[sub_int] > 1, threshold=2)`,
|
||||
`MATCH_EXACT(struct_array, $[sub_str] == "aaa", threshold=0)`,
|
||||
`MATCH_EXACT(struct_array, $[sub_str] == "aaa" && $[sub_int] > 100, threshold=3)`,
|
||||
|
||||
// Combined with other expressions (match must be last)
|
||||
`Int64Field > 0 && MATCH_ALL(struct_array, $[sub_int] > 1)`,
|
||||
`Int64Field > 0 && MATCH_ANY(struct_array, $[sub_str] == "test")`,
|
||||
`Int64Field > 0 && MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=2)`,
|
||||
|
||||
// Complex predicates
|
||||
`MATCH_ALL(struct_array, ($[sub_int] > 0 && $[sub_int] < 100) || $[sub_str] == "default")`,
|
||||
`MATCH_ANY(struct_array, !($[sub_int] < 0))`,
|
||||
|
||||
// Case insensitivity
|
||||
`match_all(struct_array, $[sub_int] > 1)`,
|
||||
`match_any(struct_array, $[sub_int] > 1)`,
|
||||
`match_least(struct_array, $[sub_int] > 1, threshold=2)`,
|
||||
`match_most(struct_array, $[sub_int] > 1, threshold=2)`,
|
||||
`match_exact(struct_array, $[sub_int] > 1, threshold=2)`,
|
||||
|
||||
// Multiple match expressions with logical operators
|
||||
`MATCH_ALL(struct_array, $[sub_int] > 1) || MATCH_ANY(struct_array, $[sub_str] == "test")`,
|
||||
`MATCH_ALL(struct_array, $[sub_int] > 1) && MATCH_ANY(struct_array, $[sub_str] == "test")`,
|
||||
`MATCH_ANY(struct_array, $[sub_int] > 1) || Int64Field > 0`,
|
||||
`MATCH_ALL(struct_array, $[sub_int] > 1) && Int64Field > 0`,
|
||||
}
|
||||
|
||||
for _, expr := range validExprs {
|
||||
assertValidExpr(t, helper, expr)
|
||||
}
|
||||
|
||||
// Test proto structure assertions
|
||||
t.Run("MatchAll_Proto", func(t *testing.T) {
|
||||
expr, err := ParseExpr(helper, `MATCH_ALL(struct_array, $[sub_int] > 1)`, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, expr.GetMatchExpr())
|
||||
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
|
||||
assert.Equal(t, planpb.MatchType_MatchAll, expr.GetMatchExpr().GetMatchType())
|
||||
assert.Equal(t, int64(0), expr.GetMatchExpr().GetCount())
|
||||
})
|
||||
|
||||
t.Run("MatchAny_Proto", func(t *testing.T) {
|
||||
expr, err := ParseExpr(helper, `MATCH_ANY(struct_array, $[sub_str] == "aaa")`, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, expr.GetMatchExpr())
|
||||
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
|
||||
assert.Equal(t, planpb.MatchType_MatchAny, expr.GetMatchExpr().GetMatchType())
|
||||
assert.Equal(t, int64(0), expr.GetMatchExpr().GetCount())
|
||||
})
|
||||
|
||||
t.Run("MatchLeast_Proto", func(t *testing.T) {
|
||||
expr, err := ParseExpr(helper, `MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=3)`, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, expr.GetMatchExpr())
|
||||
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
|
||||
assert.Equal(t, planpb.MatchType_MatchLeast, expr.GetMatchExpr().GetMatchType())
|
||||
assert.Equal(t, int64(3), expr.GetMatchExpr().GetCount())
|
||||
})
|
||||
|
||||
t.Run("MatchMost_Proto", func(t *testing.T) {
|
||||
expr, err := ParseExpr(helper, `MATCH_MOST(struct_array, $[sub_str] == "aaa", threshold=5)`, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, expr.GetMatchExpr())
|
||||
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
|
||||
assert.Equal(t, planpb.MatchType_MatchMost, expr.GetMatchExpr().GetMatchType())
|
||||
assert.Equal(t, int64(5), expr.GetMatchExpr().GetCount())
|
||||
})
|
||||
|
||||
t.Run("MatchExact_Proto", func(t *testing.T) {
|
||||
expr, err := ParseExpr(helper, `MATCH_EXACT(struct_array, $[sub_int] == 100, threshold=2)`, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, expr.GetMatchExpr())
|
||||
assert.Equal(t, "struct_array", expr.GetMatchExpr().GetStructName())
|
||||
assert.Equal(t, planpb.MatchType_MatchExact, expr.GetMatchExpr().GetMatchType())
|
||||
assert.Equal(t, int64(2), expr.GetMatchExpr().GetCount())
|
||||
})
|
||||
|
||||
// Invalid expressions
|
||||
invalidExprs := []string{
|
||||
// Nested match expressions not allowed
|
||||
`MATCH_ALL(struct_array, MATCH_ANY(struct_array, $[sub_int] > 1))`,
|
||||
`MATCH_ANY(struct_array, $[sub_int] > 1 && MATCH_ALL(struct_array, $[sub_str] == "1"))`,
|
||||
|
||||
// $[field] syntax outside match context
|
||||
`$[sub_int] > 1`,
|
||||
`Int64Field > 0 && $[sub_int] > 1`,
|
||||
|
||||
// Non-existent fields
|
||||
`MATCH_ALL(struct_array, $[non_existent_field] > 1)`,
|
||||
`MATCH_ALL(non_existent_array, $[sub_int] > 1)`,
|
||||
|
||||
// Missing parameters
|
||||
`MATCH_ALL(struct_array)`,
|
||||
`MATCH_ALL()`,
|
||||
`MATCH_ANY(struct_array)`,
|
||||
`MATCH_ANY()`,
|
||||
`MATCH_LEAST(struct_array, $[sub_int] > 1)`, // missing count
|
||||
`MATCH_MOST(struct_array, $[sub_int] > 1)`, // missing count
|
||||
`MATCH_EXACT(struct_array, $[sub_int] > 1)`, // missing count
|
||||
|
||||
// MATCH_ALL/MATCH_ANY should not have count parameter
|
||||
`MATCH_ALL(struct_array, $[sub_int] > 1, 3)`,
|
||||
`MATCH_ANY(struct_array, $[sub_int] > 1, 2)`,
|
||||
|
||||
// Invalid count values
|
||||
`MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=0)`, // count must be positive for MATCH_LEAST
|
||||
`MATCH_LEAST(struct_array, $[sub_int] > 1, threshold=-1)`, // negative count
|
||||
`MATCH_MOST(struct_array, $[sub_int] > 1, threshold=-1)`, // negative count
|
||||
`MATCH_EXACT(struct_array, $[sub_int] > 1, threshold=-1)`, // negative count
|
||||
}
|
||||
|
||||
for _, expr := range invalidExprs {
|
||||
assertInvalidExpr(t, helper, expr)
|
||||
}
|
||||
}
|
||||
|
||||
@ -362,8 +362,14 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr,
|
||||
|
||||
func handleCompareRightValue(op planpb.OpType, left *ExprWithType, right *planpb.ValueExpr) (*planpb.Expr, error) {
|
||||
dataType := left.dataType
|
||||
if typeutil.IsArrayType(dataType) && len(toColumnInfo(left).GetNestedPath()) != 0 {
|
||||
dataType = toColumnInfo(left).GetElementType()
|
||||
columnInfo := toColumnInfo(left)
|
||||
|
||||
// Use element type for casting in two cases:
|
||||
// 1. Array with nested path (e.g., arr[0])
|
||||
// 2. Array with element level flag (e.g., $[intField] in MATCH_ALL/ElementFilter)
|
||||
if typeutil.IsArrayType(dataType) && columnInfo != nil &&
|
||||
(len(columnInfo.GetNestedPath()) != 0 || columnInfo.GetIsElementLevel()) {
|
||||
dataType = columnInfo.GetElementType()
|
||||
}
|
||||
|
||||
if !left.expr.GetIsTemplate() && !isTemplateExpr(right) {
|
||||
@ -378,7 +384,6 @@ func handleCompareRightValue(op planpb.OpType, left *ExprWithType, right *planpb
|
||||
return handleBinaryArithExpr(op, leftArithExpr, left.dataType, right)
|
||||
}
|
||||
|
||||
columnInfo := toColumnInfo(left)
|
||||
if columnInfo == nil {
|
||||
return nil, errors.New("not supported to combine multiple fields")
|
||||
}
|
||||
|
||||
@ -252,6 +252,22 @@ message ElementFilterExpr {
|
||||
Expr predicate = 3;
|
||||
}
|
||||
|
||||
// MatchType defines the type of match operation for struct array queries
|
||||
enum MatchType {
|
||||
MatchAll = 0; // All elements must match the predicate
|
||||
MatchAny = 1; // At least one element matches the predicate
|
||||
MatchLeast = 2; // At least N elements match the predicate
|
||||
MatchMost = 3; // At most N elements match the predicate
|
||||
MatchExact = 4; // Exactly N elements match the predicate
|
||||
}
|
||||
|
||||
message MatchExpr {
|
||||
string struct_name = 1; // The struct array field name (e.g., struct_array)
|
||||
Expr predicate = 2; // The condition expression using $[field] syntax (e.g., $[intField] == 1 && $[strField] == "aaa")
|
||||
MatchType match_type = 3; // Type of match operation
|
||||
int64 count = 4; // For MatchLeast/MatchMost: the count parameter (N)
|
||||
}
|
||||
|
||||
message AlwaysTrueExpr {}
|
||||
|
||||
message Interval {
|
||||
@ -293,6 +309,7 @@ message Expr {
|
||||
GISFunctionFilterExpr gisfunction_filter_expr = 17;
|
||||
TimestamptzArithCompareExpr timestamptz_arith_compare_expr = 18;
|
||||
ElementFilterExpr element_filter_expr = 19;
|
||||
MatchExpr match_expr = 21;
|
||||
};
|
||||
bool is_template = 20;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user