diff --git a/internal/core/src/exec/QueryContext.h b/internal/core/src/exec/QueryContext.h index c9014c2174..02dc46668f 100644 --- a/internal/core/src/exec/QueryContext.h +++ b/internal/core/src/exec/QueryContext.h @@ -23,14 +23,15 @@ #include #include #include +#include #include "common/Common.h" #include "common/Types.h" #include "common/Exception.h" +#include "common/OpContext.h" #include "segcore/SegmentInterface.h" -namespace milvus { -namespace exec { +namespace milvus::exec { enum class ContextScope { GLOBAL = 0, SESSION = 1, QUERY = 2, Executor = 3 }; @@ -119,7 +120,8 @@ class QueryConfig : public MemConfig { static constexpr const char* kExprEvalBatchSize = "expression.eval_batch_size"; - QueryConfig(const std::unordered_map& values) + explicit QueryConfig( + const std::unordered_map& values) : MemConfig(values) { } @@ -335,7 +337,7 @@ class QueryContext : public Context { // TODO: add more class member such as memory pool class ExecContext : public Context { public: - ExecContext(QueryContext* query_context) + explicit ExecContext(QueryContext* query_context) : Context(ContextScope::Executor), query_context_(query_context) { } @@ -353,5 +355,20 @@ class ExecContext : public Context { QueryContext* query_context_; }; -} // namespace exec -} // namespace milvus \ No newline at end of file +/// @brief Helper function to check cancellation token and throw if cancelled. +/// This function safely checks the cancellation token from QueryContext and throws +/// folly::FutureCancellation if the operation has been cancelled. +/// @param query_context Pointer to QueryContext (can be nullptr) +inline void +checkCancellation(QueryContext* query_context) { + if (query_context == nullptr) { + return; + } + auto* op_context = query_context->get_op_context(); + if (op_context != nullptr && + op_context->cancellation_token.isCancellationRequested()) { + throw folly::FutureCancellation(); + } +} + +} // namespace milvus::exec diff --git a/internal/core/src/exec/expression/Expr.cpp b/internal/core/src/exec/expression/Expr.cpp index 6650cda03c..1edcf44c9a 100644 --- a/internal/core/src/exec/expression/Expr.cpp +++ b/internal/core/src/exec/expression/Expr.cpp @@ -50,7 +50,11 @@ ExprSet::Eval(int32_t begin, EvalCtx& context, std::vector& results) { results.resize(exprs_.size()); + auto* exec_ctx = context.get_exec_context(); + auto* query_ctx = + exec_ctx != nullptr ? 
exec_ctx->get_query_context() : nullptr; for (size_t i = begin; i < end; ++i) { + milvus::exec::checkCancellation(query_ctx); exprs_[i]->Eval(context, results[i]); } } diff --git a/internal/core/src/exec/expression/ExprTest.cpp b/internal/core/src/exec/expression/ExprTest.cpp index 92c7d7f050..1c144d51c5 100644 --- a/internal/core/src/exec/expression/ExprTest.cpp +++ b/internal/core/src/exec/expression/ExprTest.cpp @@ -32,6 +32,8 @@ #include "common/Json.h" #include "common/JsonCastType.h" #include "common/Types.h" +#include "common/Exception.h" +#include "folly/CancellationToken.h" #include "gtest/gtest.h" #include "index/Meta.h" #include "index/JsonInvertedIndex.h" @@ -12378,7 +12380,7 @@ TEST_P(ExprTest, TestTermWithJSON) { std::this_thread::sleep_for(std::chrono::milliseconds(200) * 2); auto seg_promote = dynamic_cast(seg.get()); query::ExecPlanNodeVisitor visitor(*seg_promote, MAX_TIMESTAMP); - int offset = 0; + for (auto [clause, ref_func, dtype] : testcases) { auto loc = serialized_expr_plan.find("@@@@@"); auto expr_plan = serialized_expr_plan; @@ -17654,3 +17656,102 @@ TEST(ExprTest, ParseGISFunctionFilterExprsMultipleOps) { EXPECT_EQ(sr->total_nq_, 5) << "Failed for operation: " << op; } } + +TEST_P(ExprTest, TestCancellationInExprEval) { + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto int64_fid = schema->AddDebugField("counter", DataType::INT64); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + auto raw_data = DataGen(schema, N); + seg->PreInsert(N); + seg->Insert(0, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + + // Test cancellation during expression evaluation + auto seg_promote = dynamic_cast(seg.get()); + + // Create a cancellation source and token + folly::CancellationSource cancellation_source; + auto cancellation_token = cancellation_source.getToken(); + + // Create a query that will be cancelled + std::string serialized_expr_plan = R"(vector_anns: < + field_id: 100 + predicates: < + unary_range_expr: < + column_info: < + field_id: 101 + data_type: Int64 + > + op: GreaterThan + value: < + int64_val: 500 + > + > + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + // Request cancellation before executing + cancellation_source.requestCancellation(); + + // Try to execute query with cancelled token - should throw + query::ExecPlanNodeVisitor visitor( + *seg_promote, MAX_TIMESTAMP, cancellation_token); + + auto proto = std::make_unique(); + auto ok = google::protobuf::TextFormat::ParseFromString( + serialized_expr_plan, proto.get()); + ASSERT_TRUE(ok); + auto plan = CreateSearchPlanByExpr(schema, + proto->SerializeAsString().data(), + proto->SerializeAsString().size()); + + // This should throw ExecOperatorException (wrapping FutureCancellation) when visiting the plan + ASSERT_THROW({ auto result = visitor.get_moved_result(*plan->plan_node_); }, + milvus::ExecOperatorException); +} + +TEST(ExprTest, TestCancellationHelper) { + // Test that checkCancellation does nothing when query_context is nullptr + ASSERT_NO_THROW(milvus::exec::checkCancellation(nullptr)); + + // Test with valid query_context but no op_context + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + auto 
seg_promote = dynamic_cast(seg.get()); + + auto query_context = std::make_unique( + "test_query", seg_promote, 0, MAX_TIMESTAMP); + + // Should not throw when op_context is nullptr + ASSERT_NO_THROW(milvus::exec::checkCancellation(query_context.get())); + + // Test with cancelled token + folly::CancellationSource source; + milvus::OpContext op_context(source.getToken()); + query_context->set_op_context(&op_context); + + // Should not throw when not cancelled + ASSERT_NO_THROW(milvus::exec::checkCancellation(query_context.get())); + + // Cancel and test + source.requestCancellation(); + ASSERT_THROW(milvus::exec::checkCancellation(query_context.get()), + folly::FutureCancellation); +} \ No newline at end of file diff --git a/internal/core/src/exec/operator/CountNode.cpp b/internal/core/src/exec/operator/CountNode.cpp index 36b21d97e5..69aa776fa7 100644 --- a/internal/core/src/exec/operator/CountNode.cpp +++ b/internal/core/src/exec/operator/CountNode.cpp @@ -53,6 +53,8 @@ PhyCountNode::AddInput(RowVectorPtr& input) { RowVectorPtr PhyCountNode::GetOutput() { + milvus::exec::checkCancellation(query_context_); + if (is_finished_ || !no_more_input_) { return nullptr; } diff --git a/internal/core/src/exec/operator/FilterBitsNode.cpp b/internal/core/src/exec/operator/FilterBitsNode.cpp index 550037001b..83904e0689 100644 --- a/internal/core/src/exec/operator/FilterBitsNode.cpp +++ b/internal/core/src/exec/operator/FilterBitsNode.cpp @@ -59,6 +59,8 @@ PhyFilterBitsNode::IsFinished() { RowVectorPtr PhyFilterBitsNode::GetOutput() { + milvus::exec::checkCancellation(query_context_); + if (AllInputProcessed()) { return nullptr; } diff --git a/internal/core/src/exec/operator/GroupByNode.cpp b/internal/core/src/exec/operator/GroupByNode.cpp index 66acad86d6..80f7b9726a 100644 --- a/internal/core/src/exec/operator/GroupByNode.cpp +++ b/internal/core/src/exec/operator/GroupByNode.cpp @@ -44,6 +44,8 @@ PhyGroupByNode::AddInput(RowVectorPtr& input) { RowVectorPtr PhyGroupByNode::GetOutput() { + milvus::exec::checkCancellation(query_context_); + if (is_finished_ || !no_more_input_) { return nullptr; } diff --git a/internal/core/src/exec/operator/IterativeFilterNode.cpp b/internal/core/src/exec/operator/IterativeFilterNode.cpp index 5f73cf8e9c..4e8347f233 100644 --- a/internal/core/src/exec/operator/IterativeFilterNode.cpp +++ b/internal/core/src/exec/operator/IterativeFilterNode.cpp @@ -112,6 +112,8 @@ insert_helper(milvus::SearchResult& search_result, RowVectorPtr PhyIterativeFilterNode::GetOutput() { + milvus::exec::checkCancellation(query_context_); + if (is_finished_ || !no_more_input_) { return nullptr; } diff --git a/internal/core/src/exec/operator/MvccNode.cpp b/internal/core/src/exec/operator/MvccNode.cpp index 9491f3b3c0..6d9fc3e4b6 100644 --- a/internal/core/src/exec/operator/MvccNode.cpp +++ b/internal/core/src/exec/operator/MvccNode.cpp @@ -43,6 +43,10 @@ PhyMvccNode::AddInput(RowVectorPtr& input) { RowVectorPtr PhyMvccNode::GetOutput() { + auto* query_context = + operator_context_->get_exec_context()->get_query_context(); + milvus::exec::checkCancellation(query_context); + if (is_finished_) { return nullptr; } diff --git a/internal/core/src/exec/operator/RandomSampleNode.cpp b/internal/core/src/exec/operator/RandomSampleNode.cpp index 8449698948..9acde896e1 100644 --- a/internal/core/src/exec/operator/RandomSampleNode.cpp +++ b/internal/core/src/exec/operator/RandomSampleNode.cpp @@ -89,6 +89,10 @@ PhyRandomSampleNode::Sample(const uint32_t N, const float factor) { RowVectorPtr 
PhyRandomSampleNode::GetOutput() { + auto* query_context = + operator_context_->get_exec_context()->get_query_context(); + milvus::exec::checkCancellation(query_context); + if (is_finished_) { return nullptr; } @@ -167,6 +171,7 @@ PhyRandomSampleNode::GetOutput() { milvus::monitor::internal_core_search_latency_random_sample.Observe( duration / 1000); is_finished_ = true; + return result; } diff --git a/internal/core/src/exec/operator/RescoresNode.cpp b/internal/core/src/exec/operator/RescoresNode.cpp index 306a235ce8..4f0a055bbb 100644 --- a/internal/core/src/exec/operator/RescoresNode.cpp +++ b/internal/core/src/exec/operator/RescoresNode.cpp @@ -48,6 +48,10 @@ PhyRescoresNode::IsFinished() { RowVectorPtr PhyRescoresNode::GetOutput() { + ExecContext* exec_context = operator_context_->get_exec_context(); + auto query_context_ = exec_context->get_query_context(); + milvus::exec::checkCancellation(query_context_); + if (is_finished_ || !no_more_input_) { return nullptr; } @@ -61,12 +65,10 @@ PhyRescoresNode::GetOutput() { std::chrono::high_resolution_clock::time_point scalar_start = std::chrono::high_resolution_clock::now(); - ExecContext* exec_context = operator_context_->get_exec_context(); - auto query_context_ = exec_context->get_query_context(); auto query_info = exec_context->get_query_config(); milvus::SearchResult search_result = query_context_->get_search_result(); auto segment = query_context_->get_segment(); - auto op_ctx = query_context_->get_op_context(); + auto op_context = query_context_->get_op_context(); // prepare segment offset FixedVector offsets; @@ -96,7 +98,7 @@ PhyRescoresNode::GetOutput() { // boost for all result if no filter if (!filter) { scorer->batch_score( - op_ctx, segment, function_mode, offsets, boost_scores); + op_context, segment, function_mode, offsets, boost_scores); continue; } @@ -122,7 +124,7 @@ PhyRescoresNode::GetOutput() { auto col_vec = std::dynamic_pointer_cast(results[0]); auto col_vec_size = col_vec->size(); TargetBitmapView bitsetview(col_vec->GetRawData(), col_vec_size); - scorer->batch_score(op_ctx, + scorer->batch_score(op_context, segment, function_mode, offsets, @@ -138,8 +140,12 @@ PhyRescoresNode::GetOutput() { auto col_vec_size = col_vec->size(); TargetBitmapView view(col_vec->GetRawData(), col_vec_size); bitset.append(view); - scorer->batch_score( - op_ctx, segment, function_mode, offsets, bitset, boost_scores); + scorer->batch_score(op_context, + segment, + function_mode, + offsets, + bitset, + boost_scores); } } @@ -183,4 +189,4 @@ PhyRescoresNode::GetOutput() { return input_; }; -} // namespace milvus::exec \ No newline at end of file +} // namespace milvus::exec diff --git a/internal/core/src/exec/operator/RescoresNodeTest.cpp b/internal/core/src/exec/operator/RescoresNodeTest.cpp index f1cd0c567f..66fe01a5b5 100644 --- a/internal/core/src/exec/operator/RescoresNodeTest.cpp +++ b/internal/core/src/exec/operator/RescoresNodeTest.cpp @@ -11,6 +11,7 @@ #include #include "common/Schema.h" +#include "common/Types.h" #include "query/Plan.h" #include "segcore/reduce_c.h" @@ -96,7 +97,7 @@ TEST(Rescorer, Normal) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); } // search result not empty but no boost filter @@ -140,7 +141,7 @@ TEST(Rescorer, Normal) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - 
segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); } // random function with seed @@ -186,7 +187,7 @@ TEST(Rescorer, Normal) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); } // random function with field as random seed @@ -232,7 +233,7 @@ TEST(Rescorer, Normal) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); } // random function with field and seed @@ -279,10 +280,10 @@ TEST(Rescorer, Normal) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); auto search_result_same_seed = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); // should return same score when use same seed for (auto i = 0; i < 10; i++) { diff --git a/internal/core/src/exec/operator/VectorSearchNode.cpp b/internal/core/src/exec/operator/VectorSearchNode.cpp index b98f240107..0dd0c0aef2 100644 --- a/internal/core/src/exec/operator/VectorSearchNode.cpp +++ b/internal/core/src/exec/operator/VectorSearchNode.cpp @@ -55,6 +55,8 @@ PhyVectorSearchNode::AddInput(RowVectorPtr& input) { RowVectorPtr PhyVectorSearchNode::GetOutput() { + milvus::exec::checkCancellation(query_context_); + if (is_finished_ || !no_more_input_) { return nullptr; } diff --git a/internal/core/src/futures/Future.h b/internal/core/src/futures/Future.h index cea6293f56..d90dfe317c 100644 --- a/internal/core/src/futures/Future.h +++ b/internal/core/src/futures/Future.h @@ -55,8 +55,7 @@ class Metrics { explicit Metrics() : time_point_(std::chrono::steady_clock::now()), queue_duration_(0), - execute_duration_(0), - cancelled_before_execute_(false) { + execute_duration_(0) { milvus::monitor::internal_cgo_inflight_task_total_all.Increment(); } @@ -75,13 +74,17 @@ class Metrics { milvus::monitor::internal_cgo_cancel_before_execute_total_all .Increment(); } else { + if (cancelled_during_execute_) { + milvus::monitor::internal_cgo_cancel_during_execute_total_all + .Increment(); + } milvus::monitor::internal_cgo_execute_duration_seconds_all.Observe( std::chrono::duration(execute_duration_).count()); } } void - withCancel() { + withEarlyCancel() { queue_duration_ = std::chrono::duration_cast( std::chrono::steady_clock::now() - time_point_); cancelled_before_execute_ = true; @@ -96,6 +99,11 @@ class Metrics { milvus::monitor::internal_cgo_executing_task_total_all.Increment(); } + void + withDuringCancel() { + cancelled_during_execute_ = true; + } + void executeDone() { auto now = std::chrono::steady_clock::now(); @@ -108,7 +116,8 @@ class Metrics { std::chrono::steady_clock::time_point time_point_; Duration queue_duration_; Duration execute_duration_; - bool cancelled_before_execute_; + bool cancelled_before_execute_{false}; + bool cancelled_during_execute_{false}; }; // FutureResult is a struct that represents the result of the future. 
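The operator hunks earlier in this patch (PhyCountNode, PhyFilterBitsNode, PhyGroupByNode, PhyIterativeFilterNode, PhyMvccNode, PhyRandomSampleNode, PhyRescoresNode, PhyVectorSearchNode) all follow one shape: GetOutput() first resolves the QueryContext, either from a cached query_context_ member or via operator_context_->get_exec_context()->get_query_context(), and then calls milvus::exec::checkCancellation() before producing any output. A minimal sketch of that shape, using a hypothetical PhyExampleNode rather than any operator touched here:

// Sketch only: PhyExampleNode is hypothetical; the pattern mirrors the
// operators changed in this patch (e.g. PhyMvccNode / PhyRandomSampleNode).
RowVectorPtr
PhyExampleNode::GetOutput() {
    // Resolve the QueryContext that owns the OpContext / cancellation token.
    auto* query_context =
        operator_context_->get_exec_context()->get_query_context();
    // Throws folly::FutureCancellation if cancellation was requested;
    // a nullptr query_context or op_context is silently ignored.
    milvus::exec::checkCancellation(query_context);

    if (is_finished_ || !no_more_input_) {
        return nullptr;
    }
    // ... normal per-operator output production continues here ...
}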
@@ -156,21 +165,13 @@ class IFuture { virtual ~IFuture() = default; }; -/// @brief a class that represents a cancellation token -class CancellationToken : public folly::CancellationToken { - public: - CancellationToken(folly::CancellationToken&& token) noexcept - : folly::CancellationToken(std::move(token)) { +/// @brief a helper function to throw a FutureCancellation exception if the token is cancelled. +static inline void +throwIfCancelled(const folly::CancellationToken& token) { + if (token.isCancellationRequested()) { + throw folly::FutureCancellation(); } - - /// @brief check if the token is cancelled, throw a FutureCancellation exception if it is. - void - throwIfCancelled() const { - if (isCancellationRequested()) { - throw folly::FutureCancellation(); - } - } -}; +} /// @brief Future is a class that bound a future with a result for /// using by cgo. @@ -183,7 +184,7 @@ class Future : public IFuture { /// returned result or exception will be handled by consumer side. template >> + std::is_invocable_r_v>> static std::unique_ptr> async(folly::Executor::KeepAlive<> executor, int priority, @@ -267,23 +268,32 @@ class Future : public IFuture { template >> + std::is_invocable_r_v>> void asyncProduce(folly::Executor::KeepAlive<> executor, int priority, Fn&& fn) { // start produce process async. - auto cancellation_token = - CancellationToken(cancellation_source_.getToken()); + auto cancellation_token = cancellation_source_.getToken(); auto runner = [fn = std::forward(fn), cancellation_token = std::move(cancellation_token), this]() { + // pre check the cancellation token if (cancellation_token.isCancellationRequested()) { - metrics_.withCancel(); + metrics_.withEarlyCancel(); throw folly::FutureCancellation(); } - auto executionGuard = - Metrics::ExecutionGuard(metrics_); - return fn(cancellation_token); + // start the execution guard. + Metrics::ExecutionGuard executionGuard( + metrics_); + + try { + return fn(cancellation_token); + } catch (const folly::FutureCancellation& e) { + metrics_.withDuringCancel(); + throw e; + } catch (...) { + throw; // rethrow the exception to the consumer side. + } }; // the runner is executed may be executed in different thread. 
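With the milvus::futures::CancellationToken wrapper removed, producer callbacks now receive a plain folly::CancellationToken and poll it through the free throwIfCancelled() helper above; asyncProduce then records an early cancellation via withEarlyCancel() and an in-flight one via withDuringCancel(). A minimal producer sketch, assuming only names visible in this patch (produce() itself is an illustrative placeholder, not part of the change):

// Sketch only: the loop mirrors the updated FutureTest / future_test_case_c.cpp
// callbacks; produce() is an illustrative name.
#include <chrono>
#include <thread>
#include <folly/CancellationToken.h>
#include "futures/Future.h"  // milvus::futures::throwIfCancelled

int
produce(const folly::CancellationToken& token) {
    for (int i = 0; i < 10; ++i) {
        // Throws folly::FutureCancellation as soon as the owning
        // folly::CancellationSource calls requestCancellation().
        milvus::futures::throwIfCancelled(token);
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
    return 1;
}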
diff --git a/internal/core/src/futures/FutureTest.cpp b/internal/core/src/futures/FutureTest.cpp index 671cffc72a..44a6c8bff1 100644 --- a/internal/core/src/futures/FutureTest.cpp +++ b/internal/core/src/futures/FutureTest.cpp @@ -84,7 +84,7 @@ TEST(Futures, Future) { { // try a async function auto future = milvus::futures::Future::async( - &executor, 0, [](milvus::futures::CancellationToken token) { + &executor, 0, [](folly::CancellationToken token) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); return new int(1); }); @@ -110,7 +110,7 @@ TEST(Futures, Future) { { // try a async function auto future = milvus::futures::Future::async( - &executor, 0, [](milvus::futures::CancellationToken token) { + &executor, 0, [](folly::CancellationToken token) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); throw milvus::SegcoreError(milvus::NotImplemented, "unimplemented"); @@ -136,7 +136,7 @@ TEST(Futures, Future) { { // try a async function auto future = milvus::futures::Future::async( - &executor, 0, [](milvus::futures::CancellationToken token) { + &executor, 0, [](folly::CancellationToken token) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); throw std::runtime_error("unimplemented"); return new int(1); @@ -160,7 +160,7 @@ TEST(Futures, Future) { { // try a async function auto future = milvus::futures::Future::async( - &executor, 0, [](milvus::futures::CancellationToken token) { + &executor, 0, [](folly::CancellationToken token) { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); throw folly::FutureNotReady(); return new int(1); @@ -185,9 +185,9 @@ TEST(Futures, Future) { { // try a async function auto future = milvus::futures::Future::async( - &executor, 0, [](milvus::futures::CancellationToken token) { + &executor, 0, [](folly::CancellationToken token) { for (int i = 0; i < 10; i++) { - token.throwIfCancelled(); + milvus::futures::throwIfCancelled(token); std::this_thread::sleep_for(std::chrono::milliseconds(100)); } return new int(1); diff --git a/internal/core/src/futures/future_test_case_c.cpp b/internal/core/src/futures/future_test_case_c.cpp index bdf7ccf130..4de12922b3 100644 --- a/internal/core/src/futures/future_test_case_c.cpp +++ b/internal/core/src/futures/future_test_case_c.cpp @@ -18,10 +18,10 @@ future_create_test_case(int interval, int loop_cnt, int case_no) { milvus::futures::getGlobalCPUExecutor(), milvus::futures::ExecutePriority::HIGH, [interval = interval, loop_cnt = loop_cnt, case_no = case_no]( - milvus::futures::CancellationToken token) { + const folly::CancellationToken& token) { for (int i = 0; i < loop_cnt; i++) { if (case_no != 0) { - token.throwIfCancelled(); + milvus::futures::throwIfCancelled(token); } std::this_thread::sleep_for( std::chrono::milliseconds(interval)); diff --git a/internal/core/src/monitor/Monitor.cpp b/internal/core/src/monitor/Monitor.cpp index 674cfb7b90..10d63d6a04 100644 --- a/internal/core/src/monitor/Monitor.cpp +++ b/internal/core/src/monitor/Monitor.cpp @@ -284,6 +284,12 @@ DEFINE_PROMETHEUS_COUNTER(internal_cgo_cancel_before_execute_total_all, internal_cgo_cancel_before_execute_total, {}); +DEFINE_PROMETHEUS_COUNTER_FAMILY(internal_cgo_cancel_during_execute_total, + "[cpp]async cgo cancel during execute count"); +DEFINE_PROMETHEUS_COUNTER(internal_cgo_cancel_during_execute_total_all, + internal_cgo_cancel_during_execute_total, + {}); + DEFINE_PROMETHEUS_GAUGE_FAMILY(internal_cgo_pool_size, "[cpp]async cgo pool size"); DEFINE_PROMETHEUS_GAUGE(internal_cgo_pool_size_all, 
internal_cgo_pool_size, {}); diff --git a/internal/core/src/monitor/Monitor.h b/internal/core/src/monitor/Monitor.h index 0366805b9c..39e2629c73 100644 --- a/internal/core/src/monitor/Monitor.h +++ b/internal/core/src/monitor/Monitor.h @@ -88,6 +88,8 @@ DECLARE_PROMETHEUS_HISTOGRAM_FAMILY(internal_cgo_execute_duration_seconds); DECLARE_PROMETHEUS_HISTOGRAM(internal_cgo_execute_duration_seconds_all); DECLARE_PROMETHEUS_COUNTER_FAMILY(internal_cgo_cancel_before_execute_total) DECLARE_PROMETHEUS_COUNTER(internal_cgo_cancel_before_execute_total_all); +DECLARE_PROMETHEUS_COUNTER_FAMILY(internal_cgo_cancel_during_execute_total); +DECLARE_PROMETHEUS_COUNTER(internal_cgo_cancel_during_execute_total_all); DECLARE_PROMETHEUS_GAUGE_FAMILY(internal_cgo_pool_size); DECLARE_PROMETHEUS_GAUGE(internal_cgo_pool_size_all); DECLARE_PROMETHEUS_GAUGE_FAMILY(internal_cgo_inflight_task_total); diff --git a/internal/core/src/query/ExecPlanNodeVisitor.cpp b/internal/core/src/query/ExecPlanNodeVisitor.cpp index 55d0b5b59a..98983c4934 100644 --- a/internal/core/src/query/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/ExecPlanNodeVisitor.cpp @@ -69,53 +69,6 @@ ExecPlanNodeVisitor::ExecuteTask( return bitset_holder; } -void -ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) { - assert(!search_result_opt_.has_value()); - auto segment = - dynamic_cast(&segment_); - AssertInfo(segment, "support SegmentSmallIndex Only"); - - auto active_count = segment->get_active_count(timestamp_); - - // PreExecute: skip all calculation - if (active_count == 0) { - search_result_opt_ = std::move( - empty_search_result(placeholder_group_->at(0).num_of_queries_)); - return; - } - - // Construct plan fragment - auto plan = plan::PlanFragment(node.plannodes_); - - // Set query context - auto query_context = - std::make_shared(DEAFULT_QUERY_ID, - segment, - active_count, - timestamp_, - collection_ttl_timestamp_, - consystency_level_, - node.plan_options_); - - query_context->set_search_info(node.search_info_); - query_context->set_placeholder_group(placeholder_group_); - - // Set op context to query context - auto op_context = milvus::OpContext(); - query_context->set_op_context(&op_context); - - // Do plan fragment task work - auto result = ExecuteTask(plan, query_context); - - // Store result - search_result_opt_ = std::move(query_context->get_search_result()); - search_result_opt_->search_storage_cost_.scanned_remote_bytes = - op_context.storage_usage.scanned_cold_bytes.load(); - search_result_opt_->search_storage_cost_.scanned_total_bytes = - op_context.storage_usage.scanned_total_bytes.load(); -} - std::unique_ptr wrap_num_entities(int64_t cnt) { auto retrieve_result = std::make_unique(); @@ -161,11 +114,11 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) { active_count, timestamp_, collection_ttl_timestamp_, - consystency_level_, + consistency_level_, node.plan_options_); // Set op context to query context - auto op_context = milvus::OpContext(); + auto op_context = milvus::OpContext(cancel_token_); query_context->set_op_context(&op_context); // Do task execution @@ -194,7 +147,49 @@ ExecPlanNodeVisitor::visit(RetrievePlanNode& node) { void ExecPlanNodeVisitor::visit(VectorPlanNode& node) { - VectorVisitorImpl(node); + assert(!search_result_opt_.has_value()); + auto segment = + dynamic_cast(&segment_); + AssertInfo(segment, "support SegmentSmallIndex Only"); + + auto active_count = segment->get_active_count(timestamp_); + + // PreExecute: skip all calculation + if (active_count == 0) { + search_result_opt_ = 
std::move( + empty_search_result(placeholder_group_->at(0).num_of_queries_)); + return; + } + + // Construct plan fragment + auto plan = plan::PlanFragment(node.plannodes_); + + // Set query context + auto query_context = + std::make_shared(DEAFULT_QUERY_ID, + segment, + active_count, + timestamp_, + collection_ttl_timestamp_, + consistency_level_, + node.plan_options_); + + query_context->set_search_info(node.search_info_); + query_context->set_placeholder_group(placeholder_group_); + + // Set op context to query context + auto op_context = milvus::OpContext(cancel_token_); + query_context->set_op_context(&op_context); + + // Do plan fragment task work + auto result = ExecuteTask(plan, query_context); + + // Store result + search_result_opt_ = std::move(query_context->get_search_result()); + search_result_opt_->search_storage_cost_.scanned_remote_bytes = + op_context.storage_usage.scanned_cold_bytes.load(); + search_result_opt_->search_storage_cost_.scanned_total_bytes = + op_context.storage_usage.scanned_total_bytes.load(); } } // namespace milvus::query diff --git a/internal/core/src/query/ExecPlanNodeVisitor.h b/internal/core/src/query/ExecPlanNodeVisitor.h index f24cc1e9f4..a10df74f60 100644 --- a/internal/core/src/query/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/ExecPlanNodeVisitor.h @@ -17,6 +17,7 @@ #include "PlanNodeVisitor.h" #include "plan/PlanNode.h" #include "exec/QueryContext.h" +#include "futures/Future.h" namespace milvus::query { @@ -34,23 +35,30 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { ExecPlanNodeVisitor(const segcore::SegmentInterface& segment, Timestamp timestamp, const PlaceholderGroup* placeholder_group, - int32_t consystency_level = 0, + const folly::CancellationToken& cancel_token = + folly::CancellationToken(), + int32_t consistency_level = 0, Timestamp collection_ttl = 0) : segment_(segment), timestamp_(timestamp), - collection_ttl_timestamp_(collection_ttl), placeholder_group_(placeholder_group), - consystency_level_(consystency_level) { + cancel_token_(cancel_token), + consistency_level_(consistency_level), + collection_ttl_timestamp_(collection_ttl) { } + // Only used for test ExecPlanNodeVisitor(const segcore::SegmentInterface& segment, Timestamp timestamp, - int32_t consystency_level = 0, + const folly::CancellationToken& cancel_token = + folly::CancellationToken(), + int32_t consistency_level = 0, Timestamp collection_ttl = 0) : segment_(segment), timestamp_(timestamp), - collection_ttl_timestamp_(collection_ttl), - consystency_level_(consystency_level) { + cancel_token_(cancel_token), + consistency_level_(consistency_level), + collection_ttl_timestamp_(collection_ttl) { placeholder_group_ = nullptr; } @@ -91,20 +99,18 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { ExecuteTask(plan::PlanFragment& plan, std::shared_ptr query_context); - private: - void - VectorVisitorImpl(VectorPlanNode& node); - private: const segcore::SegmentInterface& segment_; Timestamp timestamp_; - Timestamp collection_ttl_timestamp_; const PlaceholderGroup* placeholder_group_; + folly::CancellationToken cancel_token_; + int32_t consistency_level_ = 0; + Timestamp collection_ttl_timestamp_; SearchResultOpt search_result_opt_; RetrieveResultOpt retrieve_result_opt_; + bool expr_use_pk_index_ = false; - int32_t consystency_level_ = 0; }; // for test use only diff --git a/internal/core/src/segcore/ChunkedSegmentSealedBinlogIndexTest.cpp b/internal/core/src/segcore/ChunkedSegmentSealedBinlogIndexTest.cpp index 7245292a56..3417cda369 100644 --- 
a/internal/core/src/segcore/ChunkedSegmentSealedBinlogIndexTest.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedBinlogIndexTest.cpp @@ -256,7 +256,7 @@ TEST_P(BinlogIndexTest, AccuracyWithLoadFieldData) { ph_group.get()}; auto nlist = segcore_config.get_nlist(); auto binlog_index_sr = - segment->Search(plan.get(), ph_group.get(), 1L << 63, 0); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); ASSERT_EQ(binlog_index_sr->total_nq_, num_queries); EXPECT_EQ(binlog_index_sr->unity_topK_, topk); EXPECT_EQ(binlog_index_sr->distances_.size(), num_queries * topk); @@ -288,7 +288,8 @@ TEST_P(BinlogIndexTest, AccuracyWithLoadFieldData) { ASSERT_NO_THROW(segment->LoadIndex(load_info)); EXPECT_TRUE(segment->HasIndex(vec_field_id)); EXPECT_EQ(segment->get_row_count(), data_n); - auto ivf_sr = segment->Search(plan.get(), ph_group.get(), 1L << 63, 0); + auto ivf_sr = + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); auto similary = GetKnnSearchRecall(num_queries, binlog_index_sr->seg_offsets_.data(), topk, @@ -352,7 +353,7 @@ TEST_P(BinlogIndexTest, AccuracyWithMapFieldData) { ph_group.get()}; auto nlist = segcore_config.get_nlist(); auto binlog_index_sr = - segment->Search(plan.get(), ph_group.get(), 1L << 63, 0); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); ASSERT_EQ(binlog_index_sr->total_nq_, num_queries); EXPECT_EQ(binlog_index_sr->unity_topK_, topk); EXPECT_EQ(binlog_index_sr->distances_.size(), num_queries * topk); @@ -384,7 +385,8 @@ TEST_P(BinlogIndexTest, AccuracyWithMapFieldData) { ASSERT_NO_THROW(segment->LoadIndex(load_info)); EXPECT_TRUE(segment->HasIndex(vec_field_id)); EXPECT_EQ(segment->get_row_count(), data_n); - auto ivf_sr = segment->Search(plan.get(), ph_group.get(), 1L << 63); + auto ivf_sr = + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); auto similary = GetKnnSearchRecall(num_queries, binlog_index_sr->seg_offsets_.data(), topk, diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index f2bd11de1c..6242cb6163 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -621,8 +621,7 @@ ChunkedSegmentSealedImpl::chunk_array_view_impl( milvus::OpContext* op_ctx, FieldId field_id, int64_t chunk_id, - std::optional> offset_len = - std::nullopt) const { + std::optional> offset_len) const { std::shared_lock lck(mutex_); AssertInfo(get_bit(field_data_ready_bitset_, field_id), "Can't get bitset element at " + std::to_string(field_id.get())); @@ -638,8 +637,7 @@ ChunkedSegmentSealedImpl::chunk_vector_array_view_impl( milvus::OpContext* op_ctx, FieldId field_id, int64_t chunk_id, - std::optional> offset_len = - std::nullopt) const { + std::optional> offset_len) const { std::shared_lock lck(mutex_); AssertInfo(get_bit(field_data_ready_bitset_, field_id), "Can't get bitset element at " + std::to_string(field_id.get())); @@ -655,8 +653,7 @@ ChunkedSegmentSealedImpl::chunk_string_view_impl( milvus::OpContext* op_ctx, FieldId field_id, int64_t chunk_id, - std::optional> offset_len = - std::nullopt) const { + std::optional> offset_len) const { std::shared_lock lck(mutex_); AssertInfo(get_bit(field_data_ready_bitset_, field_id), "Can't get bitset element at " + std::to_string(field_id.get())); diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index a6980e3944..ba970d64a8 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ 
b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -617,8 +617,7 @@ SegmentGrowingImpl::chunk_string_view_impl( milvus::OpContext* op_ctx, FieldId field_id, int64_t chunk_id, - std::optional> offset_len = - std::nullopt) const { + std::optional> offset_len) const { ThrowInfo(ErrorCode::NotImplemented, "chunk string view impl not implement for growing segment"); } @@ -628,8 +627,7 @@ SegmentGrowingImpl::chunk_array_view_impl( milvus::OpContext* op_ctx, FieldId field_id, int64_t chunk_id, - std::optional> offset_len = - std::nullopt) const { + std::optional> offset_len) const { ThrowInfo(ErrorCode::NotImplemented, "chunk array view impl not implement for growing segment"); } @@ -639,8 +637,7 @@ SegmentGrowingImpl::chunk_vector_array_view_impl( milvus::OpContext* op_ctx, FieldId field_id, int64_t chunk_id, - std::optional> offset_len = - std::nullopt) const { + std::optional> offset_len) const { ThrowInfo(ErrorCode::NotImplemented, "chunk vector array view impl not implement for growing segment"); } diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index a39baf950c..f462c0ca5c 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -166,7 +166,7 @@ class SegmentGrowingImpl : public SegmentGrowing { return segcore_config_.get_chunk_rows(); } - virtual int64_t + int64_t chunk_size(FieldId field_id, int64_t chunk_id) const final { return segcore_config_.get_chunk_rows(); } @@ -376,7 +376,7 @@ class SegmentGrowingImpl : public SegmentGrowing { search_ids(BitsetType& bitset, const IdArray& id_array) const override; bool - HasIndex(FieldId field_id) const { + HasIndex(FieldId field_id) const override { auto& field_meta = schema_->operator[](field_id); if ((IsVectorDataType(field_meta.get_data_type()) || IsGeometryType(field_meta.get_data_type())) && diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index da7425e2ae..d6fd4f66b4 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -20,6 +20,7 @@ #include "common/Types.h" #include "monitor/Monitor.h" #include "query/ExecPlanNodeVisitor.h" +#include "futures/Future.h" namespace milvus::segcore { @@ -95,13 +96,18 @@ SegmentInternalInterface::Search( const query::Plan* plan, const query::PlaceholderGroup* placeholder_group, Timestamp timestamp, + const folly::CancellationToken& cancel_token, int32_t consistency_level, Timestamp collection_ttl) const { std::shared_lock lck(mutex_); milvus::tracer::AddEvent("obtained_segment_lock_mutex"); check_search(plan); - query::ExecPlanNodeVisitor visitor( - *this, timestamp, placeholder_group, consistency_level, collection_ttl); + query::ExecPlanNodeVisitor visitor(*this, + timestamp, + placeholder_group, + cancel_token, + consistency_level, + collection_ttl); auto results = std::make_unique(); *results = visitor.get_moved_result(*plan->plan_node_); results->segment_ = (void*)this; @@ -114,13 +120,14 @@ SegmentInternalInterface::Retrieve(tracer::TraceContext* trace_ctx, Timestamp timestamp, int64_t limit_size, bool ignore_non_pk, + const folly::CancellationToken& cancel_token, int32_t consistency_level, Timestamp collection_ttl) const { std::shared_lock lck(mutex_); tracer::AutoSpan span("Retrieve", tracer::GetRootSpan()); auto results = std::make_unique(); query::ExecPlanNodeVisitor visitor( - *this, timestamp, consistency_level, collection_ttl); + *this, timestamp, 
cancel_token, consistency_level, collection_ttl); auto retrieve_results = visitor.get_retrieve_result(*plan->plan_node_); retrieve_results.segment_ = (void*)this; results->set_has_more_result(retrieve_results.has_more_result); @@ -167,6 +174,7 @@ SegmentInternalInterface::Retrieve(tracer::TraceContext* trace_ctx, .count(); milvus::monitor::internal_core_retrieve_get_target_entry_latency.Observe( get_entry_cost / 1000); + milvus::futures::throwIfCancelled(cancel_token); return results; } @@ -323,8 +331,14 @@ SegmentInternalInterface::get_real_count() const { milvus::plan::GetNextPlanNodeId(), sources); plan->plan_node_->plannodes_ = plannode; plan->plan_node_->is_count_ = true; - auto res = - Retrieve(nullptr, plan.get(), MAX_TIMESTAMP, INT64_MAX, false, 0); + auto res = Retrieve(nullptr, + plan.get(), + MAX_TIMESTAMP, + INT64_MAX, + false, + folly::CancellationToken(), + 0, + 0); AssertInfo(res->fields_data().size() == 1, "count result should only have one column"); AssertInfo(res->fields_data()[0].has_scalars(), diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index ced32bedaa..5a483b26fc 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -11,6 +11,9 @@ #pragma once +#ifndef MILVUS_SEGCORE_SEGMENT_INTERFACE_H_ +#define MILVUS_SEGCORE_SEGMENT_INTERFACE_H_ + #include #include #include @@ -74,8 +77,22 @@ class SegmentInterface { Search(const query::Plan* Plan, const query::PlaceholderGroup* placeholder_group, Timestamp timestamp, - int32_t consistency_level = 0, - Timestamp collection_ttl = 0) const = 0; + const folly::CancellationToken& cancel_token, + int32_t consistency_level, + Timestamp collection_ttl) const = 0; + + // Only used for test + std::unique_ptr + Search(const query::Plan* Plan, + const query::PlaceholderGroup* placeholder_group, + Timestamp timestamp) const { + return Search(Plan, + placeholder_group, + timestamp, + folly::CancellationToken(), + 0, + 0); + } virtual std::unique_ptr Retrieve(tracer::TraceContext* trace_ctx, @@ -83,8 +100,26 @@ class SegmentInterface { Timestamp timestamp, int64_t limit_size, bool ignore_non_pk, - int32_t consistency_level = 0, - Timestamp collection_ttl = 0) const = 0; + const folly::CancellationToken& cancel_token, + int32_t consistency_level, + Timestamp collection_ttl) const = 0; + + // Only used for test + std::unique_ptr + Retrieve(tracer::TraceContext* trace_ctx, + const query::RetrievePlan* Plan, + Timestamp timestamp, + int64_t limit_size, + bool ignore_non_pk) const { + return Retrieve(trace_ctx, + Plan, + timestamp, + limit_size, + ignore_non_pk, + folly::CancellationToken(), + 0, + 0); + } virtual std::unique_ptr Retrieve(tracer::TraceContext* trace_ctx, @@ -158,10 +193,15 @@ class SegmentInterface { virtual std::vector> PinIndex(milvus::OpContext* op_ctx, FieldId field_id, - bool include_ngram = false) const { + bool include_ngram) const { return {}; }; + std::vector> + PinIndex(milvus::OpContext* op_ctx, FieldId field_id) const { + return PinIndex(op_ctx, field_id, false); + } + virtual void BulkGetJsonData(milvus::OpContext* op_ctx, FieldId field_id, @@ -304,12 +344,16 @@ class SegmentInternalInterface : public SegmentInterface { std::to_string(field_id); } + // Bring in base class Search overloads to avoid name hiding + using SegmentInterface::Search; + std::unique_ptr Search(const query::Plan* Plan, const query::PlaceholderGroup* placeholder_group, Timestamp timestamp, - int32_t consistency_level = 0, - 
Timestamp collection_ttl = 0) const override; + const folly::CancellationToken& cancel_token, + int32_t consistency_level, + Timestamp collection_ttl) const override; void FillPrimaryKeys(const query::Plan* plan, @@ -319,14 +363,18 @@ class SegmentInternalInterface : public SegmentInterface { FillTargetEntry(const query::Plan* plan, SearchResult& results) const override; + // Bring in base class Retrieve overloads to avoid name hiding + using SegmentInterface::Retrieve; + std::unique_ptr Retrieve(tracer::TraceContext* trace_ctx, const query::RetrievePlan* Plan, Timestamp timestamp, int64_t limit_size, bool ignore_non_pk, - int32_t consistency_level = 0, - Timestamp collection_ttl = 0) const override; + const folly::CancellationToken& cancel_token, + int32_t consistency_level, + Timestamp collection_ttl) const override; std::unique_ptr Retrieve(tracer::TraceContext* trace_ctx, @@ -368,10 +416,10 @@ class SegmentInternalInterface : public SegmentInterface { PinWrapper GetTextIndex(milvus::OpContext* op_ctx, FieldId field_id) const override; - virtual PinWrapper + PinWrapper GetNgramIndex(milvus::OpContext* op_ctx, FieldId field_id) const override; - virtual PinWrapper + PinWrapper GetNgramIndexForJson(milvus::OpContext* op_ctx, FieldId field_id, const std::string& nested_path) const override; @@ -507,26 +555,26 @@ class SegmentInternalInterface : public SegmentInterface { // internal API: return chunk string views in vector virtual PinWrapper< std::pair, FixedVector>> - chunk_string_view_impl(milvus::OpContext* op_ctx, - FieldId field_id, - int64_t chunk_id, - std::optional> - offset_len = std::nullopt) const = 0; + chunk_string_view_impl( + milvus::OpContext* op_ctx, + FieldId field_id, + int64_t chunk_id, + std::optional> offset_len) const = 0; virtual PinWrapper, FixedVector>> - chunk_array_view_impl(milvus::OpContext* op_ctx, - FieldId field_id, - int64_t chunk_id, - std::optional> - offset_len = std::nullopt) const = 0; + chunk_array_view_impl( + milvus::OpContext* op_ctx, + FieldId field_id, + int64_t chunk_id, + std::optional> offset_len) const = 0; virtual PinWrapper< std::pair, FixedVector>> - chunk_vector_array_view_impl(milvus::OpContext* op_ctx, - FieldId field_id, - int64_t chunk_id, - std::optional> - offset_len = std::nullopt) const = 0; + chunk_vector_array_view_impl( + milvus::OpContext* op_ctx, + FieldId field_id, + int64_t chunk_id, + std::optional> offset_len) const = 0; virtual PinWrapper< std::pair, FixedVector>> @@ -621,3 +669,5 @@ class SegmentInternalInterface : public SegmentInterface { }; } // namespace milvus::segcore + +#endif // MILVUS_SEGCORE_SEGMENT_INTERFACE_H_ diff --git a/internal/core/src/segcore/reduce_c_test.cpp b/internal/core/src/segcore/reduce_c_test.cpp index 273708b81a..0c1aeb87ab 100644 --- a/internal/core/src/segcore/reduce_c_test.cpp +++ b/internal/core/src/segcore/reduce_c_test.cpp @@ -80,7 +80,7 @@ TEST(CApiTest, ReduceNullResult) { auto slice_topKs = std::vector{1}; std::vector results; CSearchResult res; - status = CSearch(segment, plan, placeholderGroup, 1L << 63, &res); + status = CSearch(segment, plan, placeholderGroup, MAX_TIMESTAMP, &res); ASSERT_EQ(status.error_code, Success); results.push_back(res); CSearchResultDataBlobs cSearchResultData; diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index d87004c01a..be7a5d48c5 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -114,7 +114,7 @@ DeleteSearchResult(CSearchResult search_result) { delete 
res; } -CFuture* // Future +CFuture* // Future AsyncSearch(CTraceContext c_trace, CSegmentInterface c_segment, CSearchPlan c_plan, @@ -122,8 +122,8 @@ AsyncSearch(CTraceContext c_trace, uint64_t timestamp, int32_t consistency_level, uint64_t collection_ttl) { - auto segment = (milvus::segcore::SegmentInterface*)c_segment; - auto plan = (milvus::query::Plan*)c_plan; + auto segment = static_cast(c_segment); + auto plan = static_cast(c_plan); auto phg_ptr = reinterpret_cast( c_placeholder_group); @@ -136,7 +136,7 @@ AsyncSearch(CTraceContext c_trace, phg_ptr, timestamp, consistency_level, - collection_ttl](milvus::futures::CancellationToken cancel_token) { + collection_ttl](folly::CancellationToken cancel_token) { // save trace context into search_info auto& trace_ctx = plan->plan_node_->search_info_.trace_ctx_; trace_ctx.traceID = c_trace.traceID; @@ -148,8 +148,12 @@ AsyncSearch(CTraceContext c_trace, segment->LazyCheckSchema(plan->schema_); - auto search_result = segment->Search( - plan, phg_ptr, timestamp, consistency_level, collection_ttl); + auto search_result = segment->Search(plan, + phg_ptr, + timestamp, + cancel_token, + consistency_level, + collection_ttl); if (!milvus::PositivelyRelated( plan->plan_node_->search_info_.metric_type_)) { for (auto& dis : search_result->distances_) { @@ -212,7 +216,7 @@ AsyncRetrieve(CTraceContext c_trace, limit_size, ignore_non_pk, consistency_level, - collection_ttl](milvus::futures::CancellationToken cancel_token) { + collection_ttl](folly::CancellationToken cancel_token) { auto trace_ctx = milvus::tracer::TraceContext{ c_trace.traceID, c_trace.spanID, c_trace.traceFlags}; milvus::tracer::AutoSpan span("SegCoreRetrieve", &trace_ctx, true); @@ -224,6 +228,7 @@ AsyncRetrieve(CTraceContext c_trace, timestamp, limit_size, ignore_non_pk, + cancel_token, consistency_level, collection_ttl); @@ -247,7 +252,7 @@ AsyncRetrieveByOffsets(CTraceContext c_trace, milvus::futures::getGlobalCPUExecutor(), milvus::futures::ExecutePriority::HIGH, [c_trace, segment, plan, offsets, len]( - milvus::futures::CancellationToken cancel_token) { + folly::CancellationToken cancel_token) { auto trace_ctx = milvus::tracer::TraceContext{ c_trace.traceID, c_trace.spanID, c_trace.traceFlags}; milvus::tracer::AutoSpan span( diff --git a/internal/core/thirdparty/milvus-common/CMakeLists.txt b/internal/core/thirdparty/milvus-common/CMakeLists.txt index 5c1be4adc8..b153cf819c 100644 --- a/internal/core/thirdparty/milvus-common/CMakeLists.txt +++ b/internal/core/thirdparty/milvus-common/CMakeLists.txt @@ -13,7 +13,7 @@ milvus_add_pkg_config("milvus-common") set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES "") -set( MILVUS-COMMON-VERSION 3501a0c ) +set( MILVUS-COMMON-VERSION 951a806 ) set( GIT_REPOSITORY "https://github.com/zilliztech/milvus-common.git") message(STATUS "milvus-common repo: ${GIT_REPOSITORY}") diff --git a/internal/core/unittest/bench/bench_search.cpp b/internal/core/unittest/bench/bench_search.cpp index fd23421e3a..f2c18cfca1 100644 --- a/internal/core/unittest/bench/bench_search.cpp +++ b/internal/core/unittest/bench/bench_search.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include "common/type_c.h" #include "segcore/segment_c.h" #include "segcore/SegmentGrowing.h" @@ -95,9 +96,11 @@ Search_GrowingIndex(benchmark::State& state) { dataset_.raw_); Timestamp ts = 10000000; + folly::CancellationToken token; for (auto _ : state) { - auto qr = segment->Search(search_plan.get(), ph_group.get(), ts, 0); + auto qr = + 
segment->Search(search_plan.get(), ph_group.get(), ts, token, 0, 0); } } @@ -137,9 +140,11 @@ Search_Sealed(benchmark::State& state) { } Timestamp ts = 10000000; + folly::CancellationToken token; for (auto _ : state) { - auto qr = segment->Search(search_plan.get(), ph_group.get(), ts, 0); + auto qr = + segment->Search(search_plan.get(), ph_group.get(), ts, token, 0, 0); } } diff --git a/internal/core/unittest/test_float16.cpp b/internal/core/unittest/test_float16.cpp index 6cec29cf97..05cdfadb99 100644 --- a/internal/core/unittest/test_float16.cpp +++ b/internal/core/unittest/test_float16.cpp @@ -280,7 +280,7 @@ TEST(Float16, ExecWithPredicate) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto sr = segment->Search(plan.get(), ph_group.get(), 1L << 63); + auto sr = segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); int topk = 5; query::Json json = SearchResultToJson(*sr); diff --git a/internal/core/unittest/test_group_by.cpp b/internal/core/unittest/test_group_by.cpp index 4f56ffc272..a7e2774bc7 100644 --- a/internal/core/unittest/test_group_by.cpp +++ b/internal/core/unittest/test_group_by.cpp @@ -93,7 +93,7 @@ TEST(GroupBY, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); @@ -146,7 +146,7 @@ TEST(GroupBY, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); @@ -196,7 +196,7 @@ TEST(GroupBY, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); @@ -247,7 +247,7 @@ TEST(GroupBY, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); int size = group_by_values.size(); @@ -297,7 +297,7 @@ TEST(GroupBY, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); ASSERT_EQ(20, group_by_values.size()); @@ -347,7 +347,7 @@ TEST(GroupBY, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); 
CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); @@ -398,7 +398,7 @@ TEST(GroupBY, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); int size = group_by_values.size(); @@ -471,7 +471,7 @@ TEST(GroupBY, SealedData) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); @@ -584,11 +584,11 @@ TEST(GroupBY, Reduce) { CSearchResult c_search_res_1; CSearchResult c_search_res_2; - auto status = - CSearch(c_segment_1, c_plan, c_ph_group, 1L << 63, &c_search_res_1); + auto status = CSearch( + c_segment_1, c_plan, c_ph_group, MAX_TIMESTAMP, &c_search_res_1); ASSERT_EQ(status.error_code, Success); - status = - CSearch(c_segment_2, c_plan, c_ph_group, 1L << 63, &c_search_res_2); + status = CSearch( + c_segment_2, c_plan, c_ph_group, MAX_TIMESTAMP, &c_search_res_2); ASSERT_EQ(status.error_code, Success); std::vector results; results.push_back(c_search_res_1); @@ -756,7 +756,7 @@ TEST(GroupBY, GrowingRawData) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment_growing_impl->Search(plan.get(), ph_group.get(), 1L << 63); + segment_growing_impl->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, true); auto& group_by_values = search_result->group_by_values_.value(); @@ -856,7 +856,7 @@ TEST(GroupBY, GrowingIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment_growing_impl->Search(plan.get(), ph_group.get(), 1L << 63); + segment_growing_impl->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, true); auto& group_by_values = search_result->group_by_values_.value(); diff --git a/internal/core/unittest/test_group_by_json.cpp b/internal/core/unittest/test_group_by_json.cpp index 456cb2cc57..a939cab4c8 100644 --- a/internal/core/unittest/test_group_by_json.cpp +++ b/internal/core/unittest/test_group_by_json.cpp @@ -24,7 +24,8 @@ run_group_by_search(const std::string& raw_plan, auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto search_result = segment->Search(plan.get(), ph_group.get(), 1L << 63); + auto search_result = + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); CheckGroupBySearchResult(*search_result, topK, num_queries, false); return search_result; } @@ -597,11 +598,11 @@ TEST(GroupBYJSON, Reduce) { CSearchResult c_search_res_1; CSearchResult c_search_res_2; - auto status = - CSearch(c_segment_1, c_plan, c_ph_group, 1L << 63, &c_search_res_1); + auto status = CSearch( + c_segment_1, c_plan, c_ph_group, MAX_TIMESTAMP, &c_search_res_1); ASSERT_EQ(status.error_code, Success); - status = - CSearch(c_segment_2, c_plan, c_ph_group, 1L << 63, 
&c_search_res_2); + status = CSearch( + c_segment_2, c_plan, c_ph_group, MAX_TIMESTAMP, &c_search_res_2); ASSERT_EQ(status.error_code, Success); std::vector results; results.push_back(c_search_res_1); diff --git a/internal/core/unittest/test_iterative_filter.cpp b/internal/core/unittest/test_iterative_filter.cpp index b8aa97f390..bd6df50f0a 100644 --- a/internal/core/unittest/test_iterative_filter.cpp +++ b/internal/core/unittest/test_iterative_filter.cpp @@ -124,7 +124,7 @@ TEST(IterativeFilter, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); const char* raw_plan2 = R"(vector_anns: < field_id: 100 @@ -155,7 +155,7 @@ TEST(IterativeFilter, SealedIndex) { &plan_node2); auto plan2 = CreateSearchPlanFromPlanNode(schema, plan_node2); auto search_result2 = - segment->Search(plan2.get(), ph_group.get(), 1L << 63); + segment->Search(plan2.get(), ph_group.get(), MAX_TIMESTAMP); CheckFilterSearchResult( *search_result, *search_result2, topK, num_queries); } @@ -190,7 +190,7 @@ TEST(IterativeFilter, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); const char* raw_plan2 = R"(vector_anns: < field_id: 100 @@ -214,7 +214,7 @@ TEST(IterativeFilter, SealedIndex) { &plan_node2); auto plan2 = CreateSearchPlanFromPlanNode(schema, plan_node2); auto search_result2 = - segment->Search(plan2.get(), ph_group.get(), 1L << 63); + segment->Search(plan2.get(), ph_group.get(), MAX_TIMESTAMP); CheckFilterSearchResult( *search_result, *search_result2, topK, num_queries); } @@ -240,7 +240,7 @@ TEST(IterativeFilter, SealedIndex) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); const char* raw_plan2 = R"(vector_anns: < field_id: 100 @@ -255,7 +255,7 @@ TEST(IterativeFilter, SealedIndex) { &plan_node2); auto plan2 = CreateSearchPlanFromPlanNode(schema, plan_node2); auto search_result2 = - segment->Search(plan2.get(), ph_group.get(), 1L << 63); + segment->Search(plan2.get(), ph_group.get(), MAX_TIMESTAMP); CheckFilterSearchResult( *search_result, *search_result2, topK, num_queries); } @@ -322,7 +322,7 @@ TEST(IterativeFilter, SealedData) { auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); auto search_result = - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); const char* raw_plan2 = R"(vector_anns: < field_id: 100 @@ -353,7 +353,7 @@ TEST(IterativeFilter, SealedData) { &plan_node2); auto plan2 = CreateSearchPlanFromPlanNode(schema, plan_node2); auto search_result2 = - segment->Search(plan2.get(), ph_group.get(), 1L << 63); + segment->Search(plan2.get(), ph_group.get(), MAX_TIMESTAMP); CheckFilterSearchResult( *search_result, *search_result2, topK, num_queries); } @@ -428,8 +428,8 @@ TEST(IterativeFilter, GrowingRawData) { auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto search_result = - segment_growing_impl->Search(plan.get(), ph_group.get(), 1L << 63); + auto 
search_result = segment_growing_impl->Search( + plan.get(), ph_group.get(), MAX_TIMESTAMP); const char* raw_plan2 = R"(vector_anns: < field_id: 102 @@ -459,8 +459,8 @@ TEST(IterativeFilter, GrowingRawData) { auto ok2 = google::protobuf::TextFormat::ParseFromString(raw_plan2, &plan_node2); auto plan2 = CreateSearchPlanFromPlanNode(schema, plan_node2); - auto search_result2 = - segment_growing_impl->Search(plan2.get(), ph_group.get(), 1L << 63); + auto search_result2 = segment_growing_impl->Search( + plan2.get(), ph_group.get(), MAX_TIMESTAMP); CheckFilterSearchResult( *search_result, *search_result2, topK, num_queries); } @@ -548,8 +548,8 @@ TEST(IterativeFilter, GrowingIndex) { auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); - auto search_result = - segment_growing_impl->Search(plan.get(), ph_group.get(), 1L << 63); + auto search_result = segment_growing_impl->Search( + plan.get(), ph_group.get(), MAX_TIMESTAMP); const char* raw_plan2 = R"(vector_anns: < field_id: 102 @@ -579,8 +579,8 @@ TEST(IterativeFilter, GrowingIndex) { auto ok2 = google::protobuf::TextFormat::ParseFromString(raw_plan2, &plan_node2); auto plan2 = CreateSearchPlanFromPlanNode(schema, plan_node2); - auto search_result2 = - segment_growing_impl->Search(plan2.get(), ph_group.get(), 1L << 63); + auto search_result2 = segment_growing_impl->Search( + plan2.get(), ph_group.get(), MAX_TIMESTAMP); CheckFilterSearchResult( *search_result, *search_result2, topK, num_queries); } diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index f607093b16..6909cb46ff 100644 --- a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -151,7 +151,12 @@ TEST(Sealed, without_predicate) { sr = sealed_segment->Search(plan.get(), ph_group.get(), 0); EXPECT_EQ(sr->get_total_result_count(), 0); - sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp, 0, 100); + sr = sealed_segment->Search(plan.get(), + ph_group.get(), + timestamp, + folly::CancellationToken(), + 0, + 100); EXPECT_EQ(sr->get_total_result_count(), 0); } @@ -989,7 +994,12 @@ TEST(Sealed, LoadScalarIndex) { nothing_index.cache_index = CreateTestCacheIndex("test", std::move(temp2)); segment->LoadIndex(nothing_index); - auto sr = segment->Search(plan.get(), ph_group.get(), timestamp, 0, 100000); + auto sr = segment->Search(plan.get(), + ph_group.get(), + timestamp, + folly::CancellationToken(), + 0, + 100000); auto json = SearchResultToJson(*sr); std::cout << json.dump(1); } @@ -1505,7 +1515,7 @@ TEST(Sealed, LoadArrayFieldData) { ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); segment = CreateSealedWithFieldDataLoaded(schema, dataset); - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); auto ids_ds = GenRandomIds(N); auto s = dynamic_cast(segment.get()); @@ -1563,7 +1573,7 @@ TEST(Sealed, LoadArrayFieldDataWithMMap) { ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); segment = CreateSealedWithFieldDataLoaded(schema, dataset, true); - segment->Search(plan.get(), ph_group.get(), 1L << 63); + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); } TEST(Sealed, SkipIndexSkipUnaryRange) {