diff --git a/internal/core/src/futures/Executor.cpp b/internal/core/src/futures/Executor.cpp index e202aad2dc..b424809e0a 100644 --- a/internal/core/src/futures/Executor.cpp +++ b/internal/core/src/futures/Executor.cpp @@ -16,30 +16,14 @@ namespace milvus::futures { const int kNumPriority = 3; -const int kMaxQueueSizeFactor = 16; -folly::Executor::KeepAlive<> +folly::CPUThreadPoolExecutor* getGlobalCPUExecutor() { - static ExecutorSingleton singleton; - return singleton.GetCPUExecutor(); -} - -folly::Executor::KeepAlive<> -ExecutorSingleton::GetCPUExecutor() { - // TODO: fix the executor with a non-block way. - std::call_once(cpu_executor_once_, [this]() { - int num_threads = milvus::CPU_NUM; - auto num_priority = kNumPriority; - auto max_queue_size = num_threads * kMaxQueueSizeFactor; - cpu_executor_ = std::make_unique( - num_threads, - std::make_unique>(num_priority, - max_queue_size), - std::make_shared("MILVUS_CPU_")); - }); - return folly::getKeepAliveToken(cpu_executor_.get()); + static folly::CPUThreadPoolExecutor executor( + std::thread::hardware_concurrency(), + folly::CPUThreadPoolExecutor::makeDefaultPriorityQueue(kNumPriority), + std::make_shared("MILVUS_FUTURE_CPU_")); + return &executor; } }; // namespace milvus::futures \ No newline at end of file diff --git a/internal/core/src/futures/Executor.h b/internal/core/src/futures/Executor.h index c1579009df..5adfe389b3 100644 --- a/internal/core/src/futures/Executor.h +++ b/internal/core/src/futures/Executor.h @@ -18,23 +18,13 @@ namespace milvus::futures { -folly::Executor::KeepAlive<> +namespace ExecutePriority { +const int LOW = 2; +const int NORMAL = 1; +const int HIGH = 0; +} // namespace ExecutePriority + +folly::CPUThreadPoolExecutor* getGlobalCPUExecutor(); -class ExecutorSingleton { - public: - ExecutorSingleton() = default; - - ExecutorSingleton(const ExecutorSingleton&) = delete; - - ExecutorSingleton(ExecutorSingleton&&) noexcept = delete; - - folly::Executor::KeepAlive<> - GetCPUExecutor(); - - private: - std::unique_ptr cpu_executor_; - std::once_flag cpu_executor_once_; -}; - }; // namespace milvus::futures diff --git a/internal/core/src/futures/Future.h b/internal/core/src/futures/Future.h index 5a81af5eca..60eb804e96 100644 --- a/internal/core/src/futures/Future.h +++ b/internal/core/src/futures/Future.h @@ -16,7 +16,6 @@ #include #include #include - #include "future_c_types.h" #include "LeakyResult.h" #include "Ready.h" @@ -56,6 +55,8 @@ class IFuture { releaseLeakedFuture(IFuture* future) { delete future; } + + virtual ~IFuture() = default; }; /// @brief a class that represents a cancellation token @@ -176,6 +177,7 @@ class Future : public IFuture { CancellationToken(cancellation_source_.getToken()); auto runner = [fn = std::forward(fn), cancellation_token = std::move(cancellation_token)]() { + cancellation_token.throwIfCancelled(); return fn(cancellation_token); }; diff --git a/internal/core/src/futures/future_c.cpp b/internal/core/src/futures/future_c.cpp index 8d5159cff4..1221d2d653 100644 --- a/internal/core/src/futures/future_c.cpp +++ b/internal/core/src/futures/future_c.cpp @@ -14,6 +14,8 @@ #include "future_c.h" #include "folly/init/Init.h" #include "Future.h" +#include "Executor.h" +#include "log/Log.h" extern "C" void future_cancel(CFuture* future) { @@ -48,4 +50,11 @@ extern "C" void future_destroy(CFuture* future) { milvus::futures::IFuture::releaseLeakedFuture( static_cast(static_cast(future))); -} \ No newline at end of file +} + +extern "C" void +executor_set_thread_num(int thread_num) { + 
milvus::futures::getGlobalCPUExecutor()->setNumThreads(thread_num); + LOG_INFO("future executor setup cpu executor with thread num: {}", + thread_num); +} diff --git a/internal/core/src/futures/future_c.h b/internal/core/src/futures/future_c.h index 392528c98b..539f22eff1 100644 --- a/internal/core/src/futures/future_c.h +++ b/internal/core/src/futures/future_c.h @@ -39,6 +39,9 @@ future_create_test_case(int interval, int loop_cnt, int caseNo); void future_destroy(CFuture* future); +void +executor_set_thread_num(int thread_num); + #ifdef __cplusplus } #endif diff --git a/internal/core/src/futures/future_test_case_c.cpp b/internal/core/src/futures/future_test_case_c.cpp index 9c9b3359cc..bdf7ccf130 100644 --- a/internal/core/src/futures/future_test_case_c.cpp +++ b/internal/core/src/futures/future_test_case_c.cpp @@ -16,7 +16,7 @@ extern "C" CFuture* future_create_test_case(int interval, int loop_cnt, int case_no) { auto future = milvus::futures::Future::async( milvus::futures::getGlobalCPUExecutor(), - 0, + milvus::futures::ExecutePriority::HIGH, [interval = interval, loop_cnt = loop_cnt, case_no = case_no]( milvus::futures::CancellationToken token) { for (int i = 0; i < loop_cnt; i++) { diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt index a386e83c86..7d7944eda9 100644 --- a/internal/core/src/segcore/CMakeLists.txt +++ b/internal/core/src/segcore/CMakeLists.txt @@ -42,6 +42,6 @@ set(SEGCORE_FILES check_vec_index_c.cpp) add_library(milvus_segcore SHARED ${SEGCORE_FILES}) -target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage) +target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage milvus_futures) install(TARGETS milvus_segcore DESTINATION "${CMAKE_INSTALL_LIBDIR}") diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index 06643ea3f7..e662c22181 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -27,6 +27,8 @@ #include "segcore/SegmentSealedImpl.h" #include "segcore/Utils.h" #include "storage/Util.h" +#include "futures/Future.h" +#include "futures/Executor.h" #include "storage/space.h" ////////////////////////////// common interfaces ////////////////////////////// @@ -82,113 +84,129 @@ DeleteSearchResult(CSearchResult search_result) { delete res; } -CStatus -Search(CTraceContext c_trace, - CSegmentInterface c_segment, - CSearchPlan c_plan, - CPlaceholderGroup c_placeholder_group, - uint64_t timestamp, - CSearchResult* result) { - try { - auto segment = (milvus::segcore::SegmentInterface*)c_segment; - auto plan = (milvus::query::Plan*)c_plan; - auto phg_ptr = reinterpret_cast( - c_placeholder_group); +CFuture* // Future +AsyncSearch(CTraceContext c_trace, + CSegmentInterface c_segment, + CSearchPlan c_plan, + CPlaceholderGroup c_placeholder_group, + uint64_t timestamp) { + auto segment = (milvus::segcore::SegmentInterface*)c_segment; + auto plan = (milvus::query::Plan*)c_plan; + auto phg_ptr = reinterpret_cast( + c_placeholder_group); - // save trace context into search_info - auto& trace_ctx = plan->plan_node_->search_info_.trace_ctx_; - trace_ctx.traceID = c_trace.traceID; - trace_ctx.spanID = c_trace.spanID; - trace_ctx.traceFlags = c_trace.traceFlags; + auto future = milvus::futures::Future::async( + milvus::futures::getGlobalCPUExecutor(), + milvus::futures::ExecutePriority::HIGH, + [c_trace, segment, plan, phg_ptr, timestamp]( + 
milvus::futures::CancellationToken cancel_token) { + // save trace context into search_info + auto& trace_ctx = plan->plan_node_->search_info_.trace_ctx_; + trace_ctx.traceID = c_trace.traceID; + trace_ctx.spanID = c_trace.spanID; + trace_ctx.traceFlags = c_trace.traceFlags; - auto span = milvus::tracer::StartSpan("SegCoreSearch", &trace_ctx); - milvus::tracer::SetRootSpan(span); + auto span = milvus::tracer::StartSpan("SegCoreSearch", &trace_ctx); + milvus::tracer::SetRootSpan(span); - auto search_result = segment->Search(plan, phg_ptr, timestamp); - if (!milvus::PositivelyRelated( - plan->plan_node_->search_info_.metric_type_)) { - for (auto& dis : search_result->distances_) { - dis *= -1; + auto search_result = segment->Search(plan, phg_ptr, timestamp); + if (!milvus::PositivelyRelated( + plan->plan_node_->search_info_.metric_type_)) { + for (auto& dis : search_result->distances_) { + dis *= -1; + } } - } - *result = search_result.release(); - span->End(); - milvus::tracer::CloseRootSpan(); - return milvus::SuccessCStatus(); - } catch (std::exception& e) { - return milvus::FailureCStatus(&e); - } + span->End(); + milvus::tracer::CloseRootSpan(); + return search_result.release(); + }); + return static_cast(static_cast( + static_cast(future.release()))); } void DeleteRetrieveResult(CRetrieveResult* retrieve_result) { - std::free(const_cast(retrieve_result->proto_blob)); + delete[] static_cast( + const_cast(retrieve_result->proto_blob)); + delete retrieve_result; } -CStatus -Retrieve(CTraceContext c_trace, - CSegmentInterface c_segment, - CRetrievePlan c_plan, - uint64_t timestamp, - CRetrieveResult* result, - int64_t limit_size, - bool ignore_non_pk) { +/// Create a leaked CRetrieveResult from a proto. +/// Should be released by DeleteRetrieveResult. 
+CRetrieveResult* +CreateLeakedCRetrieveResultFromProto( + std::unique_ptr retrieve_result) { + auto size = retrieve_result->ByteSizeLong(); + auto buffer = new uint8_t[size]; try { - auto segment = - static_cast(c_segment); - auto plan = static_cast(c_plan); - - auto trace_ctx = milvus::tracer::TraceContext{ - c_trace.traceID, c_trace.spanID, c_trace.traceFlags}; - milvus::tracer::AutoSpan span("SegCoreRetrieve", &trace_ctx, true); - - auto retrieve_result = segment->Retrieve( - &trace_ctx, plan, timestamp, limit_size, ignore_non_pk); - - auto size = retrieve_result->ByteSizeLong(); - std::unique_ptr buffer(new uint8_t[size]); - retrieve_result->SerializePartialToArray(buffer.get(), size); - - result->proto_blob = buffer.release(); - result->proto_size = size; - - return milvus::SuccessCStatus(); + retrieve_result->SerializePartialToArray(buffer, size); } catch (std::exception& e) { - return milvus::FailureCStatus(&e); + delete[] buffer; + throw; } + + auto result = new CRetrieveResult(); + result->proto_blob = buffer; + result->proto_size = size; + return result; } -CStatus -RetrieveByOffsets(CTraceContext c_trace, - CSegmentInterface c_segment, - CRetrievePlan c_plan, - CRetrieveResult* result, - int64_t* offsets, - int64_t len) { - try { - auto segment = - static_cast(c_segment); - auto plan = static_cast(c_plan); +CFuture* // Future +AsyncRetrieve(CTraceContext c_trace, + CSegmentInterface c_segment, + CRetrievePlan c_plan, + uint64_t timestamp, + int64_t limit_size, + bool ignore_non_pk) { + auto segment = static_cast(c_segment); + auto plan = static_cast(c_plan); - auto trace_ctx = milvus::tracer::TraceContext{ - c_trace.traceID, c_trace.spanID, c_trace.traceFlags}; - milvus::tracer::AutoSpan span( - "SegCoreRetrieveByOffsets", &trace_ctx, true); + auto future = milvus::futures::Future::async( + milvus::futures::getGlobalCPUExecutor(), + milvus::futures::ExecutePriority::HIGH, + [c_trace, segment, plan, timestamp, limit_size, ignore_non_pk]( + milvus::futures::CancellationToken cancel_token) { + auto trace_ctx = milvus::tracer::TraceContext{ + c_trace.traceID, c_trace.spanID, c_trace.traceFlags}; + milvus::tracer::AutoSpan span("SegCoreRetrieve", &trace_ctx, true); - auto retrieve_result = - segment->Retrieve(&trace_ctx, plan, offsets, len); + auto retrieve_result = segment->Retrieve( + &trace_ctx, plan, timestamp, limit_size, ignore_non_pk); - auto size = retrieve_result->ByteSizeLong(); - std::unique_ptr buffer(new uint8_t[size]); - retrieve_result->SerializePartialToArray(buffer.get(), size); + return CreateLeakedCRetrieveResultFromProto( + std::move(retrieve_result)); + }); + return static_cast(static_cast( + static_cast(future.release()))); +} - result->proto_blob = buffer.release(); - result->proto_size = size; +CFuture* // Future +AsyncRetrieveByOffsets(CTraceContext c_trace, + CSegmentInterface c_segment, + CRetrievePlan c_plan, + int64_t* offsets, + int64_t len) { + auto segment = static_cast(c_segment); + auto plan = static_cast(c_plan); - return milvus::SuccessCStatus(); - } catch (std::exception& e) { - return milvus::FailureCStatus(&e); - } + auto future = milvus::futures::Future::async( + milvus::futures::getGlobalCPUExecutor(), + milvus::futures::ExecutePriority::HIGH, + [c_trace, segment, plan, offsets, len]( + milvus::futures::CancellationToken cancel_token) { + auto trace_ctx = milvus::tracer::TraceContext{ + c_trace.traceID, c_trace.spanID, c_trace.traceFlags}; + milvus::tracer::AutoSpan span( + "SegCoreRetrieveByOffsets", &trace_ctx, true); + + auto 
retrieve_result = + segment->Retrieve(&trace_ctx, plan, offsets, len); + + return CreateLeakedCRetrieveResultFromProto( + std::move(retrieve_result)); + }); + return static_cast(static_cast( + static_cast(future.release()))); } int64_t diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index e971c86d5b..ec25518348 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -20,6 +20,7 @@ extern "C" { #include #include "common/type_c.h" +#include "futures/future_c.h" #include "segcore/plan_c.h" #include "segcore/load_index_c.h" #include "segcore/load_field_data_c.h" @@ -43,33 +44,30 @@ ClearSegmentData(CSegmentInterface c_segment); void DeleteSearchResult(CSearchResult search_result); -CStatus -Search(CTraceContext c_trace, - CSegmentInterface c_segment, - CSearchPlan c_plan, - CPlaceholderGroup c_placeholder_group, - uint64_t timestamp, - CSearchResult* result); +CFuture* // Future +AsyncSearch(CTraceContext c_trace, + CSegmentInterface c_segment, + CSearchPlan c_plan, + CPlaceholderGroup c_placeholder_group, + uint64_t timestamp); void DeleteRetrieveResult(CRetrieveResult* retrieve_result); -CStatus -Retrieve(CTraceContext c_trace, - CSegmentInterface c_segment, - CRetrievePlan c_plan, - uint64_t timestamp, - CRetrieveResult* result, - int64_t limit_size, - bool ignore_non_pk); +CFuture* // Future +AsyncRetrieve(CTraceContext c_trace, + CSegmentInterface c_segment, + CRetrievePlan c_plan, + uint64_t timestamp, + int64_t limit_size, + bool ignore_non_pk); -CStatus -RetrieveByOffsets(CTraceContext c_trace, - CSegmentInterface c_segment, - CRetrievePlan c_plan, - CRetrieveResult* result, - int64_t* offsets, - int64_t len); +CFuture* // Future +AsyncRetrieveByOffsets(CTraceContext c_trace, + CSegmentInterface c_segment, + CRetrievePlan c_plan, + int64_t* offsets, + int64_t len); int64_t GetMemoryUsageInBytes(CSegmentInterface c_segment); diff --git a/internal/core/unittest/init_gtest.cpp b/internal/core/unittest/init_gtest.cpp index 3633a86f82..adc1b3b683 100644 --- a/internal/core/unittest/init_gtest.cpp +++ b/internal/core/unittest/init_gtest.cpp @@ -11,6 +11,7 @@ #include +#include "folly/init/Init.h" #include "test_utils/Constants.h" #include "storage/LocalChunkManagerSingleton.h" #include "storage/RemoteChunkManagerSingleton.h" @@ -19,6 +20,8 @@ int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); + folly::Init follyInit(&argc, &argv, false); + milvus::storage::LocalChunkManagerSingleton::GetInstance().Init( TestLocalPath); milvus::storage::RemoteChunkManagerSingleton::GetInstance().Init( diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index d8f7fc29b3..ae1af955db 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -34,6 +34,7 @@ #include "segcore/Reduce.h" #include "segcore/reduce_c.h" #include "segcore/segment_c.h" +#include "futures/Future.h" #include "test_utils/DataGen.h" #include "test_utils/PbHelper.h" #include "test_utils/indexbuilder_test_utils.h" @@ -64,14 +65,50 @@ CStatus CRetrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp, - CRetrieveResult* result) { - return Retrieve({}, - c_segment, - c_plan, - timestamp, - result, - DEFAULT_MAX_OUTPUT_SIZE, - false); + CRetrieveResult** result) { + auto future = AsyncRetrieve( + {}, c_segment, c_plan, timestamp, DEFAULT_MAX_OUTPUT_SIZE, false); + auto futurePtr = static_cast( + static_cast(static_cast(future))); + + 
std::mutex mu; + mu.lock(); + futurePtr->registerReadyCallback( + [](CLockedGoMutex* mutex) { ((std::mutex*)(mutex))->unlock(); }, + (CLockedGoMutex*)(&mu)); + mu.lock(); + + auto [retrieveResult, status] = futurePtr->leakyGet(); + if (status.error_code != 0) { + return status; + } + *result = static_cast(retrieveResult); + return status; +} + +CStatus +CRetrieveByOffsets(CSegmentInterface c_segment, + CRetrievePlan c_plan, + int64_t* offsets, + int64_t len, + CRetrieveResult** result) { + auto future = AsyncRetrieveByOffsets({}, c_segment, c_plan, offsets, len); + auto futurePtr = static_cast( + static_cast(static_cast(future))); + + std::mutex mu; + mu.lock(); + futurePtr->registerReadyCallback( + [](CLockedGoMutex* mutex) { ((std::mutex*)(mutex))->unlock(); }, + (CLockedGoMutex*)(&mu)); + mu.lock(); + + auto [retrieveResult, status] = futurePtr->leakyGet(); + if (status.error_code != 0) { + return status; + } + *result = static_cast(retrieveResult); + return status; } const char* @@ -609,15 +646,16 @@ TEST(CApiTest, MultiDeleteGrowingSegment) { plan->field_ids_ = target_field_ids; auto max_ts = dataset.timestamps_[N - 1] + 10; - CRetrieveResult retrieve_result; + CRetrieveResult* retrieve_result = nullptr; res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); auto query_result = std::make_unique(); - auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + auto suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 0); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + retrieve_result = nullptr; // retrieve pks = {2} { @@ -633,11 +671,12 @@ TEST(CApiTest, MultiDeleteGrowingSegment) { std::make_shared(DEFAULT_PLANNODE_ID, term_expr); res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); - suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 1); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + retrieve_result = nullptr; // delete pks = {2} delete_pks = {2}; @@ -658,13 +697,13 @@ TEST(CApiTest, MultiDeleteGrowingSegment) { // retrieve pks in {2} res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); - suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 0); DeleteRetrievePlan(plan.release()); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); DeleteCollection(collection); DeleteSegment(segment); @@ -721,15 +760,16 @@ TEST(CApiTest, MultiDeleteSealedSegment) { plan->field_ids_ = target_field_ids; auto max_ts = dataset.timestamps_[N - 1] + 10; - CRetrieveResult retrieve_result; + CRetrieveResult* retrieve_result = nullptr; auto res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); auto query_result = std::make_unique(); - auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + auto suc = 
query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 0); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + retrieve_result = nullptr; // retrieve pks = {2} { @@ -745,11 +785,12 @@ TEST(CApiTest, MultiDeleteSealedSegment) { std::make_shared(DEFAULT_PLANNODE_ID, term_expr); res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); - suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 1); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + retrieve_result = nullptr; // delete pks = {2} delete_pks = {2}; @@ -770,13 +811,13 @@ TEST(CApiTest, MultiDeleteSealedSegment) { // retrieve pks in {2} res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); ASSERT_EQ(res.error_code, Success); - suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 0); DeleteRetrievePlan(plan.release()); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); DeleteCollection(collection); DeleteSegment(segment); @@ -839,16 +880,17 @@ TEST(CApiTest, DeleteRepeatedPksFromGrowingSegment) { std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; - CRetrieveResult retrieve_result; + CRetrieveResult* retrieve_result = nullptr; res = CRetrieve( segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result); ASSERT_EQ(res.error_code, Success); auto query_result = std::make_unique(); - auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + auto suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 6); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + retrieve_result = nullptr; // delete data pks = {1, 2, 3} std::vector delete_row_ids = {1, 2, 3}; @@ -873,13 +915,14 @@ TEST(CApiTest, DeleteRepeatedPksFromGrowingSegment) { ASSERT_EQ(res.error_code, Success); query_result = std::make_unique(); - suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 0); DeleteRetrievePlan(plan.release()); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + retrieve_result = nullptr; DeleteCollection(collection); DeleteSegment(segment); @@ -920,16 +963,17 @@ TEST(CApiTest, DeleteRepeatedPksFromSealedSegment) { std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; - CRetrieveResult retrieve_result; + CRetrieveResult* retrieve_result = nullptr; auto res = CRetrieve( segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result); ASSERT_EQ(res.error_code, Success); auto query_result = std::make_unique(); - auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, - 
retrieve_result.proto_size); + auto suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 6); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + retrieve_result = nullptr; // delete data pks = {1, 2, 3} std::vector delete_row_ids = {1, 2, 3}; @@ -955,13 +999,13 @@ TEST(CApiTest, DeleteRepeatedPksFromSealedSegment) { ASSERT_EQ(res.error_code, Success); query_result = std::make_unique(); - suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 0); DeleteRetrievePlan(plan.release()); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); DeleteCollection(collection); DeleteSegment(segment); @@ -1030,16 +1074,17 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnGrowingSegment) { std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; - CRetrieveResult retrieve_result; + CRetrieveResult* retrieve_result = nullptr; res = CRetrieve( segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result); ASSERT_EQ(res.error_code, Success); auto query_result = std::make_unique(); - auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + auto suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 0); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + retrieve_result = nullptr; // second insert data // insert data with pks = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} , timestamps = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19} @@ -1061,13 +1106,13 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnGrowingSegment) { ASSERT_EQ(res.error_code, Success); query_result = std::make_unique(); - suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 3); DeleteRetrievePlan(plan.release()); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); DeleteCollection(collection); DeleteSegment(segment); @@ -1127,18 +1172,19 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnSealedSegment) { std::vector target_field_ids{FieldId(100), FieldId(101)}; plan->field_ids_ = target_field_ids; - CRetrieveResult retrieve_result; + CRetrieveResult* retrieve_result = nullptr; auto res = CRetrieve( segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result); ASSERT_EQ(res.error_code, Success); auto query_result = std::make_unique(); - auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + auto suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->ids().int_id().data().size(), 4); DeleteRetrievePlan(plan.release()); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + retrieve_result = nullptr; DeleteCollection(collection); DeleteSegment(segment); @@ -1324,13 +1370,21 @@ TEST(CApiTest, RetrieveTestWithExpr) { std::vector target_field_ids{FieldId(100), FieldId(101)}; 
plan->field_ids_ = target_field_ids; - CRetrieveResult retrieve_result; + CRetrieveResult* retrieve_result = nullptr; auto res = CRetrieve( segment, plan.get(), dataset.timestamps_[0], &retrieve_result); ASSERT_EQ(res.error_code, Success); + // Test Retrieve by offsets. + int64_t offsets[] = {0, 1, 2}; + CRetrieveResult* retrieve_by_offsets_result = nullptr; + res = CRetrieveByOffsets( + segment, plan.get(), offsets, 3, &retrieve_by_offsets_result); + ASSERT_EQ(res.error_code, Success); + DeleteRetrievePlan(plan.release()); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); + DeleteRetrieveResult(retrieve_by_offsets_result); DeleteCollection(collection); DeleteSegment(segment); } @@ -4324,13 +4378,13 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) { i8_fid, i16_fid, i32_fid, i64_fid, float_fid, double_fid}; plan->field_ids_ = target_field_ids; - CRetrieveResult retrieve_result; + CRetrieveResult* retrieve_result = nullptr; res = CRetrieve( segment, plan.get(), raw_data.timestamps_[N - 1], &retrieve_result); ASSERT_EQ(res.error_code, Success); auto query_result = std::make_unique(); - auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, - retrieve_result.proto_size); + auto suc = query_result->ParseFromArray(retrieve_result->proto_blob, + retrieve_result->proto_size); ASSERT_TRUE(suc); ASSERT_EQ(query_result->fields_data().size(), 6); auto fields_data = query_result->fields_data(); @@ -4369,7 +4423,7 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) { } DeleteRetrievePlan(plan.release()); - DeleteRetrieveResult(&retrieve_result); + DeleteRetrieveResult(retrieve_result); DeleteSegment(segment); } diff --git a/internal/core/unittest/test_futures.cpp b/internal/core/unittest/test_futures.cpp index e5f7baa23a..671cffc72a 100644 --- a/internal/core/unittest/test_futures.cpp +++ b/internal/core/unittest/test_futures.cpp @@ -206,5 +206,6 @@ TEST(Futures, Future) { ASSERT_EQ(r, nullptr); ASSERT_EQ(s.error_code, milvus::FollyCancel); + free((char*)(s.error_msg)); } } \ No newline at end of file diff --git a/internal/core/unittest/test_group_by.cpp b/internal/core/unittest/test_group_by.cpp index 59e2ef0dda..1295230f25 100644 --- a/internal/core/unittest/test_group_by.cpp +++ b/internal/core/unittest/test_group_by.cpp @@ -609,10 +609,10 @@ TEST(GroupBY, Reduce) { CSearchResult c_search_res_1; CSearchResult c_search_res_2; auto status = - Search({}, c_segment_1, c_plan, c_ph_group, 1L << 63, &c_search_res_1); + CSearch(c_segment_1, c_plan, c_ph_group, 1L << 63, &c_search_res_1); ASSERT_EQ(status.error_code, Success); status = - Search({}, c_segment_2, c_plan, c_ph_group, 1L << 63, &c_search_res_2); + CSearch(c_segment_2, c_plan, c_ph_group, 1L << 63, &c_search_res_2); ASSERT_EQ(status.error_code, Success); std::vector results; results.push_back(c_search_res_1); diff --git a/internal/core/unittest/test_utils/c_api_test_utils.h b/internal/core/unittest/test_utils/c_api_test_utils.h index 225a00e6b3..cabf6ec432 100644 --- a/internal/core/unittest/test_utils/c_api_test_utils.h +++ b/internal/core/unittest/test_utils/c_api_test_utils.h @@ -28,6 +28,7 @@ #include "segcore/Reduce.h" #include "segcore/reduce_c.h" #include "segcore/segment_c.h" +#include "futures/Future.h" #include "DataGen.h" #include "PbHelper.h" #include "c_api_test_utils.h" @@ -147,8 +148,24 @@ CSearch(CSegmentInterface c_segment, CPlaceholderGroup c_placeholder_group, uint64_t timestamp, CSearchResult* result) { - return Search( - {}, c_segment, c_plan, 
c_placeholder_group, timestamp, result); + auto future = + AsyncSearch({}, c_segment, c_plan, c_placeholder_group, timestamp); + auto futurePtr = static_cast( + static_cast(static_cast(future))); + + std::mutex mu; + mu.lock(); + futurePtr->registerReadyCallback( + [](CLockedGoMutex* mutex) { ((std::mutex*)(mutex))->unlock(); }, + (CLockedGoMutex*)(&mu)); + mu.lock(); + + auto [searchResult, status] = futurePtr->leakyGet(); + if (status.error_code != 0) { + return status; + } + *result = static_cast(searchResult); + return status; } } // namespace diff --git a/internal/querynodev2/segments/cgo_util.go b/internal/querynodev2/segments/cgo_util.go index 2e8c32e402..f82d25ae29 100644 --- a/internal/querynodev2/segments/cgo_util.go +++ b/internal/querynodev2/segments/cgo_util.go @@ -28,6 +28,7 @@ import "C" import ( "context" + "math" "unsafe" "github.com/golang/protobuf/proto" @@ -55,14 +56,9 @@ func HandleCStatus(ctx context.Context, status *C.CStatus, extraInfo string, fie return err } -// HandleCProto deal with the result proto returned from CGO -func HandleCProto(cRes *C.CProto, msg proto.Message) error { - // Standalone CProto is protobuf created by C side, - // Passed from c side - // memory is managed manually - lease, blob := cgoconverter.UnsafeGoBytes(&cRes.proto_blob, int(cRes.proto_size)) - defer cgoconverter.Release(lease) - +// UnmarshalCProto unmarshal the proto from C memory +func UnmarshalCProto(cRes *C.CProto, msg proto.Message) error { + blob := (*(*[math.MaxInt32]byte)(cRes.proto_blob))[:int(cRes.proto_size):int(cRes.proto_size)] return proto.Unmarshal(blob, msg) } diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index ea58f2802b..08b7aaa48a 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -163,6 +163,10 @@ func (suite *RetrieveSuite) TestRetrieveSealed() { suite.NoError(err) suite.Len(res[0].Result.Offset, 3) suite.manager.Segment.Unpin(segments) + + resultByOffsets, err := suite.sealed.RetrieveByOffsets(context.Background(), plan, []int64{0, 1}) + suite.NoError(err) + suite.Len(resultByOffsets.Offset, 0) } func (suite *RetrieveSuite) TestRetrieveGrowing() { @@ -182,6 +186,10 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() { suite.NoError(err) suite.Len(res[0].Result.Offset, 3) suite.manager.Segment.Unpin(segments) + + resultByOffsets, err := suite.growing.RetrieveByOffsets(context.Background(), plan, []int64{0, 1}) + suite.NoError(err) + suite.Len(resultByOffsets.Offset, 0) } func (suite *RetrieveSuite) TestRetrieveStreamSealed() { diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 83404096cd..08e6707b1b 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -17,8 +17,9 @@ package segments /* -#cgo pkg-config: milvus_segcore +#cgo pkg-config: milvus_segcore milvus_futures +#include "futures/future_c.h" #include "segcore/collection_c.h" #include "segcore/plan_c.h" #include "segcore/reduce_c.h" @@ -53,6 +54,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/querynodev2/segments/state" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/cgo" typeutil_internal "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -565,34 +567,39 @@ func (s *LocalSegment) Search(ctx 
context.Context, searchReq *SearchRequest) (*S defer s.ptrLock.RUnlock() traceCtx := ParseCTraceContext(ctx) + defer runtime.KeepAlive(traceCtx) + defer runtime.KeepAlive(searchReq) hasIndex := s.ExistIndex(searchReq.searchFieldID) log = log.With(zap.Bool("withIndex", hasIndex)) log.Debug("search segment...") - var searchResult SearchResult - var status C.CStatus - GetSQPool().Submit(func() (any, error) { - tr := timerecord.NewTimeRecorder("cgoSearch") - status = C.Search(traceCtx.ctx, - s.ptr, - searchReq.plan.cSearchPlan, - searchReq.cPlaceholderGroup, - C.uint64_t(searchReq.mvccTimestamp), - &searchResult.cSearchResult, - ) - runtime.KeepAlive(traceCtx) - metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) - return nil, nil - }).Await() - if err := HandleCStatus(ctx, &status, "Search failed", - zap.Int64("collectionID", s.Collection()), - zap.Int64("segmentID", s.ID()), - zap.String("segmentType", s.segmentType.String())); err != nil { + tr := timerecord.NewTimeRecorder("cgoSearch") + + future := cgo.Async( + ctx, + func() cgo.CFuturePtr { + return (cgo.CFuturePtr)(C.AsyncSearch( + traceCtx.ctx, + s.ptr, + searchReq.plan.cSearchPlan, + searchReq.cPlaceholderGroup, + C.uint64_t(searchReq.mvccTimestamp), + )) + }, + cgo.WithName("search"), + ) + defer future.Release() + result, err := future.BlockAndLeakyGet() + if err != nil { + log.Warn("Search failed") return nil, err } + metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) log.Debug("search segment done") - return &searchResult, nil + return &SearchResult{ + cSearchResult: (C.CSearchResult)(result), + }, nil } func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { @@ -612,69 +619,65 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco log.Debug("begin to retrieve") traceCtx := ParseCTraceContext(ctx) + defer runtime.KeepAlive(traceCtx) + defer runtime.KeepAlive(plan) maxLimitSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() - var retrieveResult RetrieveResult - var status C.CStatus - GetSQPool().Submit(func() (any, error) { - ts := C.uint64_t(plan.Timestamp) - tr := timerecord.NewTimeRecorder("cgoRetrieve") - status = C.Retrieve(traceCtx.ctx, - s.ptr, - plan.cRetrievePlan, - ts, - &retrieveResult.cRetrieveResult, - C.int64_t(maxLimitSize), - C.bool(plan.ignoreNonPk)) - runtime.KeepAlive(traceCtx) + tr := timerecord.NewTimeRecorder("cgoRetrieve") - metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), - metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) - log.Debug("cgo retrieve done", zap.Duration("timeTaken", tr.ElapseSpan())) - return nil, nil - }).Await() - - if err := HandleCStatus(ctx, &status, "Retrieve failed", - zap.Int64("collectionID", s.Collection()), - zap.Int64("partitionID", s.Partition()), - zap.Int64("segmentID", s.ID()), - zap.Int64("msgID", plan.msgID), - zap.String("segmentType", s.segmentType.String())); err != nil { + future := cgo.Async( + ctx, + func() cgo.CFuturePtr { + return (cgo.CFuturePtr)(C.AsyncRetrieve( + traceCtx.ctx, + s.ptr, + plan.cRetrievePlan, + C.uint64_t(plan.Timestamp), + C.int64_t(maxLimitSize), + C.bool(plan.ignoreNonPk), + )) + }, + cgo.WithName("retrieve"), + ) + defer future.Release() + result, err := 
future.BlockAndLeakyGet() + if err != nil { + log.Warn("Retrieve failed") return nil, err } + defer C.DeleteRetrieveResult((*C.CRetrieveResult)(result)) + + metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), + metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) _, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "partial-segcore-results-deserialization") defer span.End() - result := new(segcorepb.RetrieveResults) - if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil { + retrieveResult := new(segcorepb.RetrieveResults) + if err := UnmarshalCProto((*C.CRetrieveResult)(result), retrieveResult); err != nil { + log.Warn("unmarshal retrieve result failed", zap.Error(err)) return nil, err } log.Debug("retrieve segment done", - zap.Int("resultNum", len(result.Offset)), + zap.Int("resultNum", len(retrieveResult.Offset)), ) - // Sort was done by the segcore. // sort.Sort(&byPK{result}) - return result, nil + return retrieveResult, nil } func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) { + if len(offsets) == 0 { + return nil, merr.WrapErrParameterInvalid("segment offsets", "empty offsets") + } + if !s.ptrLock.RLockIf(state.IsNotReleased) { // TODO: check if the segment is readable but not released. too many related logic need to be refactor. return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } defer s.ptrLock.RUnlock() - if s.ptr == nil { - return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") - } - - if len(offsets) == 0 { - return nil, merr.WrapErrParameterInvalid("segment offsets", "empty offsets") - } - fields := []zap.Field{ zap.Int64("collectionID", s.Collection()), zap.Int64("partitionID", s.Partition()), @@ -686,40 +689,49 @@ func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan log := log.Ctx(ctx).With(fields...) 
log.Debug("begin to retrieve by offsets") - - traceCtx := ParseCTraceContext(ctx) - - var retrieveResult RetrieveResult - var status C.CStatus - tr := timerecord.NewTimeRecorder("cgoRetrieveByOffsets") - status = C.RetrieveByOffsets(traceCtx.ctx, - s.ptr, - plan.cRetrievePlan, - &retrieveResult.cRetrieveResult, - (*C.int64_t)(unsafe.Pointer(&offsets[0])), - C.int64_t(len(offsets))) - runtime.KeepAlive(traceCtx) + traceCtx := ParseCTraceContext(ctx) + defer runtime.KeepAlive(traceCtx) + defer runtime.KeepAlive(plan) + defer runtime.KeepAlive(offsets) + + future := cgo.Async( + ctx, + func() cgo.CFuturePtr { + return (cgo.CFuturePtr)(C.AsyncRetrieveByOffsets( + traceCtx.ctx, + s.ptr, + plan.cRetrievePlan, + (*C.int64_t)(unsafe.Pointer(&offsets[0])), + C.int64_t(len(offsets)), + )) + }, + cgo.WithName("retrieve-by-offsets"), + ) + defer future.Release() + result, err := future.BlockAndLeakyGet() + if err != nil { + log.Warn("RetrieveByOffsets failed") + return nil, err + } + defer C.DeleteRetrieveResult((*C.CRetrieveResult)(result)) metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) - log.Debug("cgo retrieve by offsets done", zap.Duration("timeTaken", tr.ElapseSpan())) - - if err := HandleCStatus(ctx, &status, "RetrieveByOffsets failed", fields...); err != nil { - return nil, err - } _, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "reduced-segcore-results-deserialization") defer span.End() - result := new(segcorepb.RetrieveResults) - if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil { + retrieveResult := new(segcorepb.RetrieveResults) + if err := UnmarshalCProto((*C.CRetrieveResult)(result), retrieveResult); err != nil { + log.Warn("unmarshal retrieve by offsets result failed", zap.Error(err)) return nil, err } - log.Debug("retrieve by segment offsets done") - - return result, nil + log.Debug("retrieve by segment offsets done", + zap.Int("resultNum", len(retrieveResult.Offset)), + ) + return retrieveResult, nil } func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) (dataPath string, offsetInBinlog int64) { diff --git a/internal/util/cgo/executor.go b/internal/util/cgo/executor.go new file mode 100644 index 0000000000..a589513469 --- /dev/null +++ b/internal/util/cgo/executor.go @@ -0,0 +1,36 @@ +package cgo + +/* +#cgo pkg-config: milvus_futures + +#include "futures/future_c.h" +*/ +import "C" + +import ( + "math" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/config" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// initExecutor initialize underlying cgo thread pool. 
+func initExecutor() { + pt := paramtable.Get() + initPoolSize := int(math.Ceil(pt.QueryNodeCfg.MaxReadConcurrency.GetAsFloat() * pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())) + C.executor_set_thread_num(C.int(initPoolSize)) + + resetThreadNum := func(evt *config.Event) { + if evt.HasUpdated { + pt := paramtable.Get() + newSize := int(math.Ceil(pt.QueryNodeCfg.MaxReadConcurrency.GetAsFloat() * pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())) + log.Info("reset cgo thread num", zap.Int("thread_num", newSize)) + C.executor_set_thread_num(C.int(newSize)) + } + } + pt.Watch(pt.QueryNodeCfg.MaxReadConcurrency.Key, config.NewHandler("cgo."+pt.QueryNodeCfg.MaxReadConcurrency.Key, resetThreadNum)) + pt.Watch(pt.QueryNodeCfg.CGOPoolSizeRatio.Key, config.NewHandler("cgo."+pt.QueryNodeCfg.CGOPoolSizeRatio.Key, resetThreadNum)) +} diff --git a/internal/util/cgo/futures.go b/internal/util/cgo/futures.go index aef27fe669..3b6aadf454 100644 --- a/internal/util/cgo/futures.go +++ b/internal/util/cgo/futures.go @@ -75,6 +75,8 @@ type ( // Async is a helper function to call a C async function that returns a future. func Async(ctx context.Context, f CGOAsyncFunction, opts ...Opt) Future { + initCGO() + options := getDefaultOpt() // apply options. for _, opt := range opts { diff --git a/internal/util/cgo/futures_test.go b/internal/util/cgo/futures_test.go index 120ba922a5..5f2a6360bc 100644 --- a/internal/util/cgo/futures_test.go +++ b/internal/util/cgo/futures_test.go @@ -18,7 +18,7 @@ import ( func TestMain(m *testing.M) { paramtable.Init() - InitCGO() + initCGO() exitCode := m.Run() if exitCode > 0 { os.Exit(exitCode) @@ -47,10 +47,6 @@ func TestFutureWithSuccessCase(t *testing.T) { _, err = future.BlockAndLeakyGet() assert.ErrorIs(t, err, ErrConsumed) - - assert.Eventually(t, func() bool { - return unreleasedCnt.Load() == 0 - }, time.Second, time.Millisecond*100) } func TestFutureWithCaseNoInterrupt(t *testing.T) { @@ -186,10 +182,6 @@ func TestFutures(t *testing.T) { assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Nil(t, result) runtime.GC() - - assert.Eventually(t, func() bool { - return unreleasedCnt.Load() == 0 - }, time.Second, time.Millisecond*100) } func TestConcurrent(t *testing.T) { @@ -259,8 +251,14 @@ func TestConcurrent(t *testing.T) { }) defer future.Release() result, err := future.BlockAndLeakyGet() - assert.NoError(t, err) - assert.Equal(t, 0, getCInt(result)) + if err == nil { + assert.Equal(t, 0, getCInt(result)) + } else { + // the future may be queued and not started, + // so the underlying task may be throw a cancel exception if it's not started. 
+ assert.ErrorIs(t, err, merr.ErrSegcoreFollyCancel) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) + } freeCInt(result) }() } @@ -271,8 +269,4 @@ func TestConcurrent(t *testing.T) { return stat.ActiveCount == 0 }, 5*time.Second, 100*time.Millisecond) runtime.GC() - - assert.Eventually(t, func() bool { - return unreleasedCnt.Load() == 0 - }, time.Second, time.Millisecond*100) } diff --git a/internal/util/cgo/futures_test_case.go b/internal/util/cgo/futures_test_case.go index 1e3afb832c..3cc933c095 100644 --- a/internal/util/cgo/futures_test_case.go +++ b/internal/util/cgo/futures_test_case.go @@ -16,8 +16,6 @@ import ( "context" "time" "unsafe" - - "go.uber.org/atomic" ) const ( @@ -27,8 +25,6 @@ const ( caseNoThrowSegcoreException int = 3 ) -var unreleasedCnt = atomic.NewInt32(0) - type testCase struct { interval time.Duration loopCnt int @@ -39,12 +35,7 @@ func createFutureWithTestCase(ctx context.Context, testCase testCase) Future { f := func() CFuturePtr { return (CFuturePtr)(C.future_create_test_case(C.int(testCase.interval.Milliseconds()), C.int(testCase.loopCnt), C.int(testCase.caseNo))) } - future := Async(ctx, f, - WithName("createFutureWithTestCase"), - WithReleaser(func() { - unreleasedCnt.Dec() - })) - unreleasedCnt.Inc() + future := Async(ctx, f, WithName("createFutureWithTestCase")) return future } diff --git a/internal/util/cgo/manager_active.go b/internal/util/cgo/manager_active.go index c003a2c15c..37c6011f18 100644 --- a/internal/util/cgo/manager_active.go +++ b/internal/util/cgo/manager_active.go @@ -1,14 +1,12 @@ package cgo import ( - "math" "reflect" "sync" "go.uber.org/atomic" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -23,20 +21,12 @@ var ( initOnce sync.Once ) -// InitCGO initializes the cgo caller and future manager. -// Please call this function before using any cgo utilities. -func InitCGO() { +// initCGO initializes the cgo caller and future manager. +func initCGO() { initOnce.Do(func() { nodeID := paramtable.GetStringNodeID() - chSize := int64(math.Ceil(float64(hardware.GetCPUNum()) * paramtable.Get().QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())) - if chSize <= 0 { - chSize = 1 - } - caller = &cgoCaller{ - // TODO: temporary solution, need to find a better way to set the pool size. - ch: make(chan struct{}, chSize), - nodeID: nodeID, - } + initCaller(nodeID) + initExecutor() futureManager = newActiveFutureManager(nodeID) futureManager.Run() }) diff --git a/internal/util/cgo/pool.go b/internal/util/cgo/pool.go index f75090b5f3..789db284e9 100644 --- a/internal/util/cgo/pool.go +++ b/internal/util/cgo/pool.go @@ -1,14 +1,28 @@ package cgo import ( + "math" "runtime" "time" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) var caller *cgoCaller +func initCaller(nodeID string) { + chSize := int64(math.Ceil(float64(hardware.GetCPUNum()) * paramtable.Get().QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())) + if chSize <= 0 { + chSize = 1 + } + caller = &cgoCaller{ + ch: make(chan struct{}, chSize), + nodeID: nodeID, + } +} + // getCGOCaller returns the cgoCaller instance. 
func getCGOCaller() *cgoCaller { return caller diff --git a/pkg/metrics/cgo_metrics.go b/pkg/metrics/cgo_metrics.go index 77815e4952..d237493a40 100644 --- a/pkg/metrics/cgo_metrics.go +++ b/pkg/metrics/cgo_metrics.go @@ -2,14 +2,31 @@ package metrics import ( "sync" + "time" "github.com/prometheus/client_golang/prometheus" ) var ( - subsystemCGO = "cgo" - cgoLabelName = "name" - once sync.Once + subsystemCGO = "cgo" + cgoLabelName = "name" + once sync.Once + bucketsForCGOCall = []float64{ + 10 * time.Nanosecond.Seconds(), + 100 * time.Nanosecond.Seconds(), + 250 * time.Nanosecond.Seconds(), + 500 * time.Nanosecond.Seconds(), + time.Microsecond.Seconds(), + 10 * time.Microsecond.Seconds(), + 20 * time.Microsecond.Seconds(), + 50 * time.Microsecond.Seconds(), + 100 * time.Microsecond.Seconds(), + 250 * time.Microsecond.Seconds(), + 500 * time.Microsecond.Seconds(), + time.Millisecond.Seconds(), + 2 * time.Millisecond.Seconds(), + 10 * time.Millisecond.Seconds(), + } ActiveFutureTotal = prometheus.NewGaugeVec( prometheus.GaugeOpts{ @@ -38,6 +55,7 @@ var ( Subsystem: subsystemCGO, Name: "cgo_duration_seconds", Help: "Histogram of cgo call duration in seconds.", + Buckets: bucketsForCGOCall, }, []string{ nodeIDLabelName, cgoLabelName, @@ -50,6 +68,7 @@ var ( Subsystem: subsystemCGO, Name: "cgo_queue_duration_seconds", Help: "Duration of cgo call in queue.", + Buckets: bucketsForCGOCall, }, []string{ nodeIDLabelName, }, @@ -59,8 +78,9 @@ var ( // RegisterCGOMetrics registers the cgo metrics. func RegisterCGOMetrics(registry *prometheus.Registry) { once.Do(func() { - prometheus.MustRegister(RunningCgoCallTotal) - prometheus.MustRegister(CGODuration) - prometheus.MustRegister(CGOQueueDuration) + registry.MustRegister(ActiveFutureTotal) + registry.MustRegister(RunningCgoCallTotal) + registry.MustRegister(CGODuration) + registry.MustRegister(CGOQueueDuration) }) } diff --git a/pkg/metrics/metrics_test.go b/pkg/metrics/metrics_test.go index 99c2d8d414..03bb0d879e 100644 --- a/pkg/metrics/metrics_test.go +++ b/pkg/metrics/metrics_test.go @@ -38,6 +38,7 @@ func TestRegisterMetrics(t *testing.T) { RegisterMetaMetrics(r) RegisterStorageMetrics(r) RegisterMsgStreamMetrics(r) + RegisterCGOMetrics(r) }) } diff --git a/pkg/metrics/querynode_metrics.go b/pkg/metrics/querynode_metrics.go index 500d5d6ce3..4675a3431f 100644 --- a/pkg/metrics/querynode_metrics.go +++ b/pkg/metrics/querynode_metrics.go @@ -792,6 +792,8 @@ func RegisterQueryNode(registry *prometheus.Registry) { registry.MustRegister(QueryNodeSegmentPruneRatio) registry.MustRegister(QueryNodeApplyBFCost) registry.MustRegister(QueryNodeForwardDeleteCost) + // Add cgo metrics + RegisterCGOMetrics(registry) } func CleanupQueryNodeCollectionMetrics(nodeID int64, collectionID int64) {
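
The hunks above all follow one producer/consumer pattern: an Async* C entry point builds a milvus::futures::Future<T> on the global CPU executor, leaks it across the C boundary as a CFuture*, and the caller waits either through cgo.Async/registerReadyCallback (Go side) or by blocking on a locked std::mutex (the updated unit tests) before calling leakyGet(). Below is a condensed C++ sketch of that pattern, assuming the futures headers introduced above; AsyncDoWork, WaitDoWork, DoWork, and CMyResult are illustrative placeholders, not symbols from this change.

#include <mutex>

#include "futures/Future.h"
#include "futures/Executor.h"
#include "futures/future_c.h"

// Illustrative result type and work function (placeholders for this sketch only).
struct CMyResult {
    int value;
};

static CMyResult*
DoWork() {
    return new CMyResult{42};
}

// Producer: schedule work on the global CPU executor and leak the future to C,
// mirroring AsyncSearch/AsyncRetrieve above.
extern "C" CFuture*
AsyncDoWork() {
    auto future = milvus::futures::Future<CMyResult>::async(
        milvus::futures::getGlobalCPUExecutor(),
        milvus::futures::ExecutePriority::NORMAL,
        [](milvus::futures::CancellationToken token) {
            // long-running work; token.throwIfCancelled() may be checked at safe points
            return DoWork();  // ownership of the result moves into the future
        });
    // leak the future across the C boundary; the caller releases it with future_destroy()
    return static_cast<CFuture*>(static_cast<void*>(
        static_cast<milvus::futures::IFuture*>(future.release())));
}

// Consumer: synchronous wait, the way the CSearch/CRetrieve test helpers above do it.
static CMyResult*
WaitDoWork(CFuture* cfuture) {
    auto future =
        static_cast<milvus::futures::IFuture*>(static_cast<void*>(cfuture));
    std::mutex mu;
    mu.lock();
    future->registerReadyCallback(
        [](CLockedGoMutex* mutex) { ((std::mutex*)(mutex))->unlock(); },
        (CLockedGoMutex*)(&mu));
    mu.lock();  // blocks until the ready callback unlocks the mutex

    auto [result, status] = future->leakyGet();
    future_destroy(cfuture);
    if (status.error_code != 0) {
        // error_msg is allocated by the futures layer (see test_futures.cpp above)
        free((char*)(status.error_msg));
        return nullptr;
    }
    return static_cast<CMyResult*>(result);
}

Ownership follows the same split as AsyncSearch/AsyncRetrieve: the lambda's return value is leaked into the future, leakyGet() hands it to the caller, and the future object itself is released separately (future_destroy on the C side, future.Release() on the Go side).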