From 601a8b801bfa1b3a69084bf0e63d32ea5bd31361 Mon Sep 17 00:00:00 2001
From: zhagnlu <1542303831@qq.com>
Date: Tue, 9 Jan 2024 17:08:48 +0800
Subject: [PATCH] fix: add move cursor function to physical expr (#29603)

#29570

Signed-off-by: luzhang
Co-authored-by: luzhang
---
 .../core/src/exec/expression/AlwaysTrueExpr.h |  9 +++
 .../core/src/exec/expression/CompareExpr.h    | 28 ++++++++
 .../core/src/exec/expression/ConjunctExpr.cpp | 13 +++-
 .../core/src/exec/expression/ConjunctExpr.h   | 12 +++-
 internal/core/src/exec/expression/Expr.h      | 53 ++++++++++++++
 .../src/exec/expression/LogicalBinaryExpr.h   |  6 ++
 .../src/exec/expression/LogicalUnaryExpr.h    |  5 ++
 internal/core/unittest/test_expr.cpp          | 71 +++++++++++++++++++
 8 files changed, 194 insertions(+), 3 deletions(-)

diff --git a/internal/core/src/exec/expression/AlwaysTrueExpr.h b/internal/core/src/exec/expression/AlwaysTrueExpr.h
index c2acd9ba15..ffb5750a31 100644
--- a/internal/core/src/exec/expression/AlwaysTrueExpr.h
+++ b/internal/core/src/exec/expression/AlwaysTrueExpr.h
@@ -45,6 +45,15 @@ class PhyAlwaysTrueExpr : public Expr {
     void
     Eval(EvalCtx& context, VectorPtr& result) override;
 
+    void
+    MoveCursor() override {  // advance one batch, clipped at active_count_
+        int64_t real_batch_size = current_pos_ + batch_size_ >= active_count_
+                                      ? active_count_ - current_pos_
+                                      : batch_size_;
+
+        current_pos_ += real_batch_size;
+    }
+
  private:
     std::shared_ptr expr_;
     int64_t active_count_;
diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h
index c05974eb54..392e6d21c7 100644
--- a/internal/core/src/exec/expression/CompareExpr.h
+++ b/internal/core/src/exec/expression/CompareExpr.h
@@ -115,6 +115,34 @@ class PhyCompareFilterExpr : public Expr {
     void
     Eval(EvalCtx& context, VectorPtr& result) override;
 
+    // Advance the (chunk, offset) cursor by one batch without evaluating
+    // any row.
+    void
+    MoveCursor() override {
+        int64_t processed_rows = 0;
+        for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_;
+             ++chunk_id) {
+            // The last chunk holds whatever is left after the full chunks.
+            auto chunk_size = chunk_id == num_chunk_ - 1
+                                  ? active_count_ - chunk_id * size_per_chunk_
+                                  : size_per_chunk_;
+
+            // Resume mid-chunk only in the chunk the cursor currently sits in.
+            for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0;
+                 i < chunk_size;
+                 ++i) {
+                if (++processed_rows >= batch_size_) {
+                    current_chunk_id_ = chunk_id;
+                    current_chunk_pos_ = i + 1;
+                    // Stop after exactly one batch; otherwise the loops keep
+                    // running and drag the cursor to the end of the data.
+                    return;
+                }
+            }
+        }
+        // Fewer than batch_size_ rows were left: the cursor is not moved.
+    }
+
  private:
     int64_t
     GetNextBatchSize();
diff --git a/internal/core/src/exec/expression/ConjunctExpr.cpp b/internal/core/src/exec/expression/ConjunctExpr.cpp
index 1c1498b11e..a26b98dda7 100644
--- a/internal/core/src/exec/expression/ConjunctExpr.cpp
+++ b/internal/core/src/exec/expression/ConjunctExpr.cpp
@@ -97,13 +97,20 @@ PhyConjunctFilterExpr::UpdateResult(ColumnVectorPtr& input_result,
 }
 
 bool
-PhyConjunctFilterExpr::CanSkipNextExprs(ColumnVectorPtr& vec) {
+PhyConjunctFilterExpr::CanSkipFollowingExprs(ColumnVectorPtr& vec) {
     if ((is_and_ && AllFalse(vec)) || (!is_and_ && AllTrue(vec))) {
         return true;
     }
     return false;
 }
 
+void
+PhyConjunctFilterExpr::SkipFollowingExprs(int start) {
+    for (int i = start; i < inputs_.size(); ++i) {
+        inputs_[i]->MoveCursor();  // advance inputs_[start..] without Eval
+    }
+}
+
 void
 PhyConjunctFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
     for (int i = 0; i < inputs_.size(); ++i) {
@@ -112,7 +119,8 @@ PhyConjunctFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
         if (i == 0) {
             result = input_result;
             auto all_flat_result = GetColumnVector(result);
-            if (CanSkipNextExprs(all_flat_result)) {
+            if (CanSkipFollowingExprs(all_flat_result)) {
+                SkipFollowingExprs(i + 1);  // the rest are not evaluated
                 return;
             }
             continue;
@@ -122,6 +130,7 @@ PhyConjunctFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
         auto active_rows =
             UpdateResult(input_flat_result, context, all_flat_result);
         if (active_rows == 0) {
+            SkipFollowingExprs(i + 1);  // the rest are not evaluated
             return;
         }
     }
diff --git a/internal/core/src/exec/expression/ConjunctExpr.h b/internal/core/src/exec/expression/ConjunctExpr.h
index 6027f56b60..bd81059977 100644
--- a/internal/core/src/exec/expression/ConjunctExpr.h
+++ b/internal/core/src/exec/expression/ConjunctExpr.h
@@ -70,6 +70,13 @@
class PhyConjunctFilterExpr : public Expr {
     void
     Eval(EvalCtx& context, VectorPtr& result) override;
 
+    void
+    MoveCursor() override {  // forwards the cursor move to every input
+        for (auto& input : inputs_) {
+            input->MoveCursor();
+        }
+    }
+
  private:
     int64_t
     UpdateResult(ColumnVectorPtr& input_result,
@@ -80,7 +87,10 @@ class PhyConjunctFilterExpr : public Expr {
     ResolveType(const std::vector& inputs);
 
     bool
-    CanSkipNextExprs(ColumnVectorPtr& vec);
+    CanSkipFollowingExprs(ColumnVectorPtr& vec);
+
+    void
+    SkipFollowingExprs(int start);  // MoveCursor() on inputs_[start..]
     // true if conjunction (and), false if disjunction (or).
     bool is_and_;
     std::vector input_order_;
diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h
index 6cd981f1a1..eb474e92e5 100644
--- a/internal/core/src/exec/expression/Expr.h
+++ b/internal/core/src/exec/expression/Expr.h
@@ -66,6 +66,12 @@ class Expr {
     Eval(EvalCtx& context, VectorPtr& result) {
     }
 
+    // Advance the cursor to the next batch only,
+    // skipping the actual evaluation as an optimization.
+    virtual void
+    MoveCursor() {  // default implementation does nothing
+    }
+
  protected:
     DataType type_;
     const std::vector> inputs_;
@@ -118,6 +124,53 @@ class SegmentExpr : public Expr {
         }
     }
 
+    void
+    MoveCursorForData() {  // advance the data cursor by at most one batch
+        if (segment_->type() == SegmentType::Sealed) {
+            auto size =  // rows to skip, capped at what is left
+                std::min(active_count_ - current_data_chunk_pos_, batch_size_);
+            current_data_chunk_pos_ += size;
+        } else {  // not sealed: walk chunk by chunk
+            int64_t processed_size = 0;  // rows consumed so far in this call
+            for (size_t i = current_data_chunk_; i < num_data_chunk_; i++) {
+                auto data_pos =  // resume offset applies to the first chunk only
+                    (i == current_data_chunk_) ? current_data_chunk_pos_ : 0;
+                auto size = (i == (num_data_chunk_ - 1) &&  // rows left in chunk i
+                             active_count_ % size_per_chunk_ != 0)
+                                ? active_count_ % size_per_chunk_ - data_pos
+                                : size_per_chunk_ - data_pos;
+
+                size = std::min(size, batch_size_ - processed_size);  // cap to batch
+
+                processed_size += size;
+                if (processed_size >= batch_size_) {  // batch complete
+                    current_data_chunk_ = i;
+                    current_data_chunk_pos_ = data_pos + size;
+                    break;
+                }
+            }
+        }
+    }
+
+    void
+    MoveCursorForIndex() {  // advance the index cursor by at most one batch
+        AssertInfo(segment_->type() == SegmentType::Sealed,
+                   "index mode only for sealed segment");
+        auto size =  // rows to skip, capped at what is left
+            std::min(active_count_ - current_index_chunk_pos_, batch_size_);
+
+        current_index_chunk_pos_ += size;
+    }
+
+    void
+    MoveCursor() override {
+        if (is_index_mode_) {  // pick the cursor that matches the current mode
+            MoveCursorForIndex();
+        } else {
+            MoveCursorForData();
+        }
+    }
+
     int64_t
     GetNextBatchSize() {
         auto current_chunk =
diff --git a/internal/core/src/exec/expression/LogicalBinaryExpr.h b/internal/core/src/exec/expression/LogicalBinaryExpr.h
index 25dfb4d934..c94df0b8b8 100644
--- a/internal/core/src/exec/expression/LogicalBinaryExpr.h
+++ b/internal/core/src/exec/expression/LogicalBinaryExpr.h
@@ -69,6 +69,12 @@ class PhyLogicalBinaryExpr : public Expr {
     void
     Eval(EvalCtx& context, VectorPtr& result) override;
 
+    void
+    MoveCursor() override {  // moves the cursors of both operands
+        inputs_[0]->MoveCursor();
+        inputs_[1]->MoveCursor();
+    }
+
  private:
     std::shared_ptr expr_;
 };
diff --git a/internal/core/src/exec/expression/LogicalUnaryExpr.h b/internal/core/src/exec/expression/LogicalUnaryExpr.h
index bc7d9a526a..da5a0e0c97 100644
--- a/internal/core/src/exec/expression/LogicalUnaryExpr.h
+++ b/internal/core/src/exec/expression/LogicalUnaryExpr.h
@@ -39,6 +39,11 @@ class PhyLogicalUnaryExpr : public Expr {
     void
     Eval(EvalCtx& context, VectorPtr& result) override;
 
+    void
+    MoveCursor() override {  // moves the cursor of the single operand
+        inputs_[0]->MoveCursor();
+    }
+
  private:
     std::shared_ptr expr_;
 };
diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp
index 65f60dfdf4..1316a623da 100644
--- a/internal/core/unittest/test_expr.cpp
+++ b/internal/core/unittest/test_expr.cpp
@@ -1927,6 +1927,77 @@ TEST(Expr,
TestGrowingSegmentGetBatchSize) {
     }
 }
 
+TEST(Expr, TestConjuctExpr) {  // AND of two range filters over many batches
+    using namespace milvus;
+    using namespace milvus::query;
+    using namespace milvus::segcore;
+    auto schema = std::make_shared<Schema>();
+    auto vec_fid = schema->AddDebugField(
+        "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
+    auto int8_fid = schema->AddDebugField("int8", DataType::INT8);
+    auto int8_1_fid = schema->AddDebugField("int81", DataType::INT8);
+    auto int16_fid = schema->AddDebugField("int16", DataType::INT16);
+    auto int16_1_fid = schema->AddDebugField("int161", DataType::INT16);
+    auto int32_fid = schema->AddDebugField("int32", DataType::INT32);
+    auto int32_1_fid = schema->AddDebugField("int321", DataType::INT32);
+    auto int64_fid = schema->AddDebugField("int64", DataType::INT64);
+    auto int64_1_fid = schema->AddDebugField("int641", DataType::INT64);
+    auto str1_fid = schema->AddDebugField("string1", DataType::VARCHAR);
+    auto str2_fid = schema->AddDebugField("string2", DataType::VARCHAR);
+    auto float_fid = schema->AddDebugField("float", DataType::FLOAT);
+    auto double_fid = schema->AddDebugField("double", DataType::DOUBLE);
+    schema->set_primary_field_id(str1_fid);
+
+    auto seg = CreateSealedSegment(schema);
+    int N = 1000000;
+    auto raw_data = DataGen(schema, N);
+    // load field data
+    auto fields = schema->get_fields();
+    for (auto field_data : raw_data.raw_->fields_data()) {
+        int64_t field_id = field_data.field_id();
+
+        auto info = FieldDataInfo(field_data.field_id(), N, "/tmp/a");
+        auto field_meta = fields.at(FieldId(field_id));
+        info.channel->push(
+            CreateFieldDataFromDataArray(N, &field_data, field_meta));
+        info.channel->close();
+
+        seg->LoadFieldData(FieldId(field_id), info);
+    }
+    query::ExecPlanNodeVisitor visitor(*seg, MAX_TIMESTAMP);
+
+    auto build_expr = [&](int l, int r) -> expr::TypedExprPtr {  // l < x < r
+        ::milvus::proto::plan::GenericValue value;
+        value.set_int64_val(l);
+        auto left = std::make_shared<expr::UnaryRangeFilterExpr>(
+            expr::ColumnInfo(int64_fid, DataType::INT64),
+            proto::plan::OpType::GreaterThan,
+            value);
+        value.set_int64_val(r);
+        auto right = std::make_shared<expr::UnaryRangeFilterExpr>(
+            expr::ColumnInfo(int64_fid, DataType::INT64),
+            proto::plan::OpType::LessThan,
+            value);
+
+        return std::make_shared<expr::LogicalBinaryExpr>(
+            expr::LogicalBinaryExpr::OpType::And, left, right);
+    };
+
+    std::vector<std::pair<int, int>> test_case = {
+        {100, 0}, {0, 100}, {8192, 8194}};
+    for (auto& pair : test_case) {
+        std::cout << pair.first << "|" << pair.second << std::endl;
+        auto expr = build_expr(pair.first, pair.second);
+        auto plan =
+            std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, expr);
+        BitsetType final;
+        visitor.ExecuteExprNode(plan, seg.get(), N, final);
+        for (int i = 0; i < N; ++i) {  // NOTE(review): assumes int64 row i == i
+            EXPECT_EQ(final[i], pair.first < i && i < pair.second) << i;
+        }
+    }
+}
+
 TEST(Expr, TestUnaryBenchTest) {
     using namespace milvus;
     using namespace milvus::query;