enhance: async search and retrieve in cgo (#33228)

issue: #30926, #33132
related pr: #33133

---------

Signed-off-by: chyezh <chyezh@outlook.com>
chyezh 2024-06-22 09:38:02 +08:00 committed by GitHub
parent c85644e1b3
commit 259a682673
26 changed files with 495 additions and 350 deletions
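
In short: the blocking segcore C entry points `Search` / `Retrieve` / `RetrieveByOffsets` become `AsyncSearch` / `AsyncRetrieve` / `AsyncRetrieveByOffsets`, which enqueue the work on the futures CPU executor and immediately return a leaked `CFuture*`; the Go side then awaits that future through the `internal/util/cgo` helpers instead of submitting a blocking cgo call to `GetSQPool()`. A condensed sketch of the new calling pattern, using only names that appear in the diffs below (the wrapper name `searchAsync` is mine; error handling, KeepAlive pinning, tracing and metrics are trimmed):

```go
// Condensed from the LocalSegment.Search change below; a sketch, not a drop-in implementation.
func searchAsync(ctx context.Context, s *LocalSegment, req *SearchRequest) (*SearchResult, error) {
	future := cgo.Async(ctx, func() cgo.CFuturePtr {
		// AsyncSearch only schedules the query on segcore's futures executor and
		// returns a leaked CFuture*; it does not block the calling goroutine.
		return (cgo.CFuturePtr)(C.AsyncSearch(
			ParseCTraceContext(ctx).ctx,
			s.ptr,
			req.plan.cSearchPlan,
			req.cPlaceholderGroup,
			C.uint64_t(req.mvccTimestamp),
		))
	}, cgo.WithName("search"))
	defer future.Release()

	// BlockAndLeakyGet parks this goroutine until the C++ task finishes;
	// cancelling ctx cancels the queued folly task as well.
	result, err := future.BlockAndLeakyGet()
	if err != nil {
		return nil, err
	}
	return &SearchResult{cSearchResult: (C.CSearchResult)(result)}, nil
}
```

Retrieve follows the same shape, except the leaked result is a `*C.CRetrieveResult` that must be freed with `C.DeleteRetrieveResult` once its protobuf blob has been unmarshalled.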

View File

@ -16,30 +16,14 @@
namespace milvus::futures { namespace milvus::futures {
const int kNumPriority = 3; const int kNumPriority = 3;
const int kMaxQueueSizeFactor = 16;
folly::Executor::KeepAlive<> folly::CPUThreadPoolExecutor*
getGlobalCPUExecutor() { getGlobalCPUExecutor() {
static ExecutorSingleton singleton; static folly::CPUThreadPoolExecutor executor(
return singleton.GetCPUExecutor(); std::thread::hardware_concurrency(),
} folly::CPUThreadPoolExecutor::makeDefaultPriorityQueue(kNumPriority),
std::make_shared<folly::NamedThreadFactory>("MILVUS_FUTURE_CPU_"));
folly::Executor::KeepAlive<> return &executor;
ExecutorSingleton::GetCPUExecutor() {
// TODO: fix the executor with a non-block way.
std::call_once(cpu_executor_once_, [this]() {
int num_threads = milvus::CPU_NUM;
auto num_priority = kNumPriority;
auto max_queue_size = num_threads * kMaxQueueSizeFactor;
cpu_executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(
num_threads,
std::make_unique<folly::PriorityLifoSemMPMCQueue<
folly::CPUThreadPoolExecutor::CPUTask,
folly::QueueBehaviorIfFull::BLOCK>>(num_priority,
max_queue_size),
std::make_shared<folly::NamedThreadFactory>("MILVUS_CPU_"));
});
return folly::getKeepAliveToken(cpu_executor_.get());
} }
}; // namespace milvus::futures }; // namespace milvus::futures

View File

@ -18,23 +18,13 @@
namespace milvus::futures { namespace milvus::futures {
folly::Executor::KeepAlive<> namespace ExecutePriority {
const int LOW = 2;
const int NORMAL = 1;
const int HIGH = 0;
} // namespace ExecutePriority
folly::CPUThreadPoolExecutor*
getGlobalCPUExecutor(); getGlobalCPUExecutor();
class ExecutorSingleton {
public:
ExecutorSingleton() = default;
ExecutorSingleton(const ExecutorSingleton&) = delete;
ExecutorSingleton(ExecutorSingleton&&) noexcept = delete;
folly::Executor::KeepAlive<>
GetCPUExecutor();
private:
std::unique_ptr<folly::Executor> cpu_executor_;
std::once_flag cpu_executor_once_;
};
}; // namespace milvus::futures }; // namespace milvus::futures
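
Taken together, the two executor hunks above change `getGlobalCPUExecutor()` to hand out a raw pointer to one process-wide `folly::CPUThreadPoolExecutor` built with `makeDefaultPriorityQueue(kNumPriority)`, i.e. three priority classes exposed to callers as `ExecutePriority::HIGH/NORMAL/LOW`, instead of a `folly::Executor::KeepAlive<>`. Returning the concrete executor type is what lets the new `executor_set_thread_num` export in the next file call `setNumThreads` on it at runtime, and the segcore hunks further down submit search and retrieve work at `ExecutePriority::HIGH`.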

View File

@ -16,7 +16,6 @@
#include <folly/CancellationToken.h> #include <folly/CancellationToken.h>
#include <folly/futures/Future.h> #include <folly/futures/Future.h>
#include <folly/futures/SharedPromise.h> #include <folly/futures/SharedPromise.h>
#include "future_c_types.h" #include "future_c_types.h"
#include "LeakyResult.h" #include "LeakyResult.h"
#include "Ready.h" #include "Ready.h"
@ -56,6 +55,8 @@ class IFuture {
releaseLeakedFuture(IFuture* future) { releaseLeakedFuture(IFuture* future) {
delete future; delete future;
} }
virtual ~IFuture() = default;
}; };
/// @brief a class that represents a cancellation token /// @brief a class that represents a cancellation token
@ -176,6 +177,7 @@ class Future : public IFuture {
CancellationToken(cancellation_source_.getToken()); CancellationToken(cancellation_source_.getToken());
auto runner = [fn = std::forward<Fn>(fn), auto runner = [fn = std::forward<Fn>(fn),
cancellation_token = std::move(cancellation_token)]() { cancellation_token = std::move(cancellation_token)]() {
cancellation_token.throwIfCancelled();
return fn(cancellation_token); return fn(cancellation_token);
}; };

View File

@ -14,6 +14,8 @@
#include "future_c.h" #include "future_c.h"
#include "folly/init/Init.h" #include "folly/init/Init.h"
#include "Future.h" #include "Future.h"
#include "Executor.h"
#include "log/Log.h"
extern "C" void extern "C" void
future_cancel(CFuture* future) { future_cancel(CFuture* future) {
@ -49,3 +51,10 @@ future_destroy(CFuture* future) {
milvus::futures::IFuture::releaseLeakedFuture( milvus::futures::IFuture::releaseLeakedFuture(
static_cast<milvus::futures::IFuture*>(static_cast<void*>(future))); static_cast<milvus::futures::IFuture*>(static_cast<void*>(future)));
} }
extern "C" void
executor_set_thread_num(int thread_num) {
milvus::futures::getGlobalCPUExecutor()->setNumThreads(thread_num);
LOG_INFO("future executor setup cpu executor with thread num: {}",
thread_num);
}

View File

@ -39,6 +39,9 @@ future_create_test_case(int interval, int loop_cnt, int caseNo);
void void
future_destroy(CFuture* future); future_destroy(CFuture* future);
void
executor_set_thread_num(int thread_num);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -16,7 +16,7 @@ extern "C" CFuture*
future_create_test_case(int interval, int loop_cnt, int case_no) { future_create_test_case(int interval, int loop_cnt, int case_no) {
auto future = milvus::futures::Future<int>::async( auto future = milvus::futures::Future<int>::async(
milvus::futures::getGlobalCPUExecutor(), milvus::futures::getGlobalCPUExecutor(),
0, milvus::futures::ExecutePriority::HIGH,
[interval = interval, loop_cnt = loop_cnt, case_no = case_no]( [interval = interval, loop_cnt = loop_cnt, case_no = case_no](
milvus::futures::CancellationToken token) { milvus::futures::CancellationToken token) {
for (int i = 0; i < loop_cnt; i++) { for (int i = 0; i < loop_cnt; i++) {

View File

@ -42,6 +42,6 @@ set(SEGCORE_FILES
check_vec_index_c.cpp) check_vec_index_c.cpp)
add_library(milvus_segcore SHARED ${SEGCORE_FILES}) add_library(milvus_segcore SHARED ${SEGCORE_FILES})
target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage) target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage milvus_futures)
install(TARGETS milvus_segcore DESTINATION "${CMAKE_INSTALL_LIBDIR}") install(TARGETS milvus_segcore DESTINATION "${CMAKE_INSTALL_LIBDIR}")

View File

@ -27,6 +27,8 @@
#include "segcore/SegmentSealedImpl.h" #include "segcore/SegmentSealedImpl.h"
#include "segcore/Utils.h" #include "segcore/Utils.h"
#include "storage/Util.h" #include "storage/Util.h"
#include "futures/Future.h"
#include "futures/Executor.h"
#include "storage/space.h" #include "storage/space.h"
////////////////////////////// common interfaces ////////////////////////////// ////////////////////////////// common interfaces //////////////////////////////
@ -82,19 +84,22 @@ DeleteSearchResult(CSearchResult search_result) {
delete res; delete res;
} }
CStatus CFuture* // Future<milvus::SearchResult*>
Search(CTraceContext c_trace, AsyncSearch(CTraceContext c_trace,
CSegmentInterface c_segment, CSegmentInterface c_segment,
CSearchPlan c_plan, CSearchPlan c_plan,
CPlaceholderGroup c_placeholder_group, CPlaceholderGroup c_placeholder_group,
uint64_t timestamp, uint64_t timestamp) {
CSearchResult* result) {
try {
auto segment = (milvus::segcore::SegmentInterface*)c_segment; auto segment = (milvus::segcore::SegmentInterface*)c_segment;
auto plan = (milvus::query::Plan*)c_plan; auto plan = (milvus::query::Plan*)c_plan;
auto phg_ptr = reinterpret_cast<const milvus::query::PlaceholderGroup*>( auto phg_ptr = reinterpret_cast<const milvus::query::PlaceholderGroup*>(
c_placeholder_group); c_placeholder_group);
auto future = milvus::futures::Future<milvus::SearchResult>::async(
milvus::futures::getGlobalCPUExecutor(),
milvus::futures::ExecutePriority::HIGH,
[c_trace, segment, plan, phg_ptr, timestamp](
milvus::futures::CancellationToken cancel_token) {
// save trace context into search_info // save trace context into search_info
auto& trace_ctx = plan->plan_node_->search_info_.trace_ctx_; auto& trace_ctx = plan->plan_node_->search_info_.trace_ctx_;
trace_ctx.traceID = c_trace.traceID; trace_ctx.traceID = c_trace.traceID;
@ -111,33 +116,56 @@ Search(CTraceContext c_trace,
dis *= -1; dis *= -1;
} }
} }
*result = search_result.release();
span->End(); span->End();
milvus::tracer::CloseRootSpan(); milvus::tracer::CloseRootSpan();
return milvus::SuccessCStatus(); return search_result.release();
} catch (std::exception& e) { });
return milvus::FailureCStatus(&e); return static_cast<CFuture*>(static_cast<void*>(
} static_cast<milvus::futures::IFuture*>(future.release())));
} }
void void
DeleteRetrieveResult(CRetrieveResult* retrieve_result) { DeleteRetrieveResult(CRetrieveResult* retrieve_result) {
std::free(const_cast<void*>(retrieve_result->proto_blob)); delete[] static_cast<uint8_t*>(
const_cast<void*>(retrieve_result->proto_blob));
delete retrieve_result;
} }
CStatus /// Create a leaked CRetrieveResult from a proto.
Retrieve(CTraceContext c_trace, /// Should be released by DeleteRetrieveResult.
CRetrieveResult*
CreateLeakedCRetrieveResultFromProto(
std::unique_ptr<milvus::proto::segcore::RetrieveResults> retrieve_result) {
auto size = retrieve_result->ByteSizeLong();
auto buffer = new uint8_t[size];
try {
retrieve_result->SerializePartialToArray(buffer, size);
} catch (std::exception& e) {
delete[] buffer;
throw;
}
auto result = new CRetrieveResult();
result->proto_blob = buffer;
result->proto_size = size;
return result;
}
CFuture* // Future<CRetrieveResult>
AsyncRetrieve(CTraceContext c_trace,
CSegmentInterface c_segment, CSegmentInterface c_segment,
CRetrievePlan c_plan, CRetrievePlan c_plan,
uint64_t timestamp, uint64_t timestamp,
CRetrieveResult* result,
int64_t limit_size, int64_t limit_size,
bool ignore_non_pk) { bool ignore_non_pk) {
try { auto segment = static_cast<milvus::segcore::SegmentInterface*>(c_segment);
auto segment =
static_cast<milvus::segcore::SegmentInterface*>(c_segment);
auto plan = static_cast<const milvus::query::RetrievePlan*>(c_plan); auto plan = static_cast<const milvus::query::RetrievePlan*>(c_plan);
auto future = milvus::futures::Future<CRetrieveResult>::async(
milvus::futures::getGlobalCPUExecutor(),
milvus::futures::ExecutePriority::HIGH,
[c_trace, segment, plan, timestamp, limit_size, ignore_non_pk](
milvus::futures::CancellationToken cancel_token) {
auto trace_ctx = milvus::tracer::TraceContext{ auto trace_ctx = milvus::tracer::TraceContext{
c_trace.traceID, c_trace.spanID, c_trace.traceFlags}; c_trace.traceID, c_trace.spanID, c_trace.traceFlags};
milvus::tracer::AutoSpan span("SegCoreRetrieve", &trace_ctx, true); milvus::tracer::AutoSpan span("SegCoreRetrieve", &trace_ctx, true);
@ -145,31 +173,27 @@ Retrieve(CTraceContext c_trace,
auto retrieve_result = segment->Retrieve( auto retrieve_result = segment->Retrieve(
&trace_ctx, plan, timestamp, limit_size, ignore_non_pk); &trace_ctx, plan, timestamp, limit_size, ignore_non_pk);
auto size = retrieve_result->ByteSizeLong(); return CreateLeakedCRetrieveResultFromProto(
std::unique_ptr<uint8_t[]> buffer(new uint8_t[size]); std::move(retrieve_result));
retrieve_result->SerializePartialToArray(buffer.get(), size); });
return static_cast<CFuture*>(static_cast<void*>(
result->proto_blob = buffer.release(); static_cast<milvus::futures::IFuture*>(future.release())));
result->proto_size = size;
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(&e);
}
} }
CStatus CFuture* // Future<CRetrieveResult>
RetrieveByOffsets(CTraceContext c_trace, AsyncRetrieveByOffsets(CTraceContext c_trace,
CSegmentInterface c_segment, CSegmentInterface c_segment,
CRetrievePlan c_plan, CRetrievePlan c_plan,
CRetrieveResult* result,
int64_t* offsets, int64_t* offsets,
int64_t len) { int64_t len) {
try { auto segment = static_cast<milvus::segcore::SegmentInterface*>(c_segment);
auto segment =
static_cast<milvus::segcore::SegmentInterface*>(c_segment);
auto plan = static_cast<const milvus::query::RetrievePlan*>(c_plan); auto plan = static_cast<const milvus::query::RetrievePlan*>(c_plan);
auto future = milvus::futures::Future<CRetrieveResult>::async(
milvus::futures::getGlobalCPUExecutor(),
milvus::futures::ExecutePriority::HIGH,
[c_trace, segment, plan, offsets, len](
milvus::futures::CancellationToken cancel_token) {
auto trace_ctx = milvus::tracer::TraceContext{ auto trace_ctx = milvus::tracer::TraceContext{
c_trace.traceID, c_trace.spanID, c_trace.traceFlags}; c_trace.traceID, c_trace.spanID, c_trace.traceFlags};
milvus::tracer::AutoSpan span( milvus::tracer::AutoSpan span(
@ -178,17 +202,11 @@ RetrieveByOffsets(CTraceContext c_trace,
auto retrieve_result = auto retrieve_result =
segment->Retrieve(&trace_ctx, plan, offsets, len); segment->Retrieve(&trace_ctx, plan, offsets, len);
auto size = retrieve_result->ByteSizeLong(); return CreateLeakedCRetrieveResultFromProto(
std::unique_ptr<uint8_t[]> buffer(new uint8_t[size]); std::move(retrieve_result));
retrieve_result->SerializePartialToArray(buffer.get(), size); });
return static_cast<CFuture*>(static_cast<void*>(
result->proto_blob = buffer.release(); static_cast<milvus::futures::IFuture*>(future.release())));
result->proto_size = size;
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(&e);
}
} }
int64_t int64_t

View File

@ -20,6 +20,7 @@ extern "C" {
#include <stdint.h> #include <stdint.h>
#include "common/type_c.h" #include "common/type_c.h"
#include "futures/future_c.h"
#include "segcore/plan_c.h" #include "segcore/plan_c.h"
#include "segcore/load_index_c.h" #include "segcore/load_index_c.h"
#include "segcore/load_field_data_c.h" #include "segcore/load_field_data_c.h"
@ -43,31 +44,28 @@ ClearSegmentData(CSegmentInterface c_segment);
void void
DeleteSearchResult(CSearchResult search_result); DeleteSearchResult(CSearchResult search_result);
CStatus CFuture* // Future<CSearchResultBody>
Search(CTraceContext c_trace, AsyncSearch(CTraceContext c_trace,
CSegmentInterface c_segment, CSegmentInterface c_segment,
CSearchPlan c_plan, CSearchPlan c_plan,
CPlaceholderGroup c_placeholder_group, CPlaceholderGroup c_placeholder_group,
uint64_t timestamp, uint64_t timestamp);
CSearchResult* result);
void void
DeleteRetrieveResult(CRetrieveResult* retrieve_result); DeleteRetrieveResult(CRetrieveResult* retrieve_result);
CStatus CFuture* // Future<CRetrieveResult>
Retrieve(CTraceContext c_trace, AsyncRetrieve(CTraceContext c_trace,
CSegmentInterface c_segment, CSegmentInterface c_segment,
CRetrievePlan c_plan, CRetrievePlan c_plan,
uint64_t timestamp, uint64_t timestamp,
CRetrieveResult* result,
int64_t limit_size, int64_t limit_size,
bool ignore_non_pk); bool ignore_non_pk);
CStatus CFuture* // Future<CRetrieveResult>
RetrieveByOffsets(CTraceContext c_trace, AsyncRetrieveByOffsets(CTraceContext c_trace,
CSegmentInterface c_segment, CSegmentInterface c_segment,
CRetrievePlan c_plan, CRetrievePlan c_plan,
CRetrieveResult* result,
int64_t* offsets, int64_t* offsets,
int64_t len); int64_t len);

View File

@ -11,6 +11,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "folly/init/Init.h"
#include "test_utils/Constants.h" #include "test_utils/Constants.h"
#include "storage/LocalChunkManagerSingleton.h" #include "storage/LocalChunkManagerSingleton.h"
#include "storage/RemoteChunkManagerSingleton.h" #include "storage/RemoteChunkManagerSingleton.h"
@ -19,6 +20,8 @@
int int
main(int argc, char** argv) { main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv); ::testing::InitGoogleTest(&argc, argv);
folly::Init follyInit(&argc, &argv, false);
milvus::storage::LocalChunkManagerSingleton::GetInstance().Init( milvus::storage::LocalChunkManagerSingleton::GetInstance().Init(
TestLocalPath); TestLocalPath);
milvus::storage::RemoteChunkManagerSingleton::GetInstance().Init( milvus::storage::RemoteChunkManagerSingleton::GetInstance().Init(

View File

@ -34,6 +34,7 @@
#include "segcore/Reduce.h" #include "segcore/Reduce.h"
#include "segcore/reduce_c.h" #include "segcore/reduce_c.h"
#include "segcore/segment_c.h" #include "segcore/segment_c.h"
#include "futures/Future.h"
#include "test_utils/DataGen.h" #include "test_utils/DataGen.h"
#include "test_utils/PbHelper.h" #include "test_utils/PbHelper.h"
#include "test_utils/indexbuilder_test_utils.h" #include "test_utils/indexbuilder_test_utils.h"
@ -64,14 +65,50 @@ CStatus
CRetrieve(CSegmentInterface c_segment, CRetrieve(CSegmentInterface c_segment,
CRetrievePlan c_plan, CRetrievePlan c_plan,
uint64_t timestamp, uint64_t timestamp,
CRetrieveResult* result) { CRetrieveResult** result) {
return Retrieve({}, auto future = AsyncRetrieve(
c_segment, {}, c_segment, c_plan, timestamp, DEFAULT_MAX_OUTPUT_SIZE, false);
c_plan, auto futurePtr = static_cast<milvus::futures::IFuture*>(
timestamp, static_cast<void*>(static_cast<CFuture*>(future)));
result,
DEFAULT_MAX_OUTPUT_SIZE, std::mutex mu;
false); mu.lock();
futurePtr->registerReadyCallback(
[](CLockedGoMutex* mutex) { ((std::mutex*)(mutex))->unlock(); },
(CLockedGoMutex*)(&mu));
mu.lock();
auto [retrieveResult, status] = futurePtr->leakyGet();
if (status.error_code != 0) {
return status;
}
*result = static_cast<CRetrieveResult*>(retrieveResult);
return status;
}
CStatus
CRetrieveByOffsets(CSegmentInterface c_segment,
CRetrievePlan c_plan,
int64_t* offsets,
int64_t len,
CRetrieveResult** result) {
auto future = AsyncRetrieveByOffsets({}, c_segment, c_plan, offsets, len);
auto futurePtr = static_cast<milvus::futures::IFuture*>(
static_cast<void*>(static_cast<CFuture*>(future)));
std::mutex mu;
mu.lock();
futurePtr->registerReadyCallback(
[](CLockedGoMutex* mutex) { ((std::mutex*)(mutex))->unlock(); },
(CLockedGoMutex*)(&mu));
mu.lock();
auto [retrieveResult, status] = futurePtr->leakyGet();
if (status.error_code != 0) {
return status;
}
*result = static_cast<CRetrieveResult*>(retrieveResult);
return status;
} }
const char* const char*
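
For the synchronous tests, the `CRetrieve` / `CRetrieveByOffsets` helpers above (and the matching `CSearch` helper in c_api_test_utils further down) turn the async API back into a blocking call: the first `mu.lock()` acquires the mutex, `registerReadyCallback` makes the future's completion callback unlock it, so the second `mu.lock()` blocks the test thread until the C++ task has finished; `leakyGet()` then yields the leaked result pointer plus a `CStatus`.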
@ -609,15 +646,16 @@ TEST(CApiTest, MultiDeleteGrowingSegment) {
plan->field_ids_ = target_field_ids; plan->field_ids_ = target_field_ids;
auto max_ts = dataset.timestamps_[N - 1] + 10; auto max_ts = dataset.timestamps_[N - 1] + 10;
CRetrieveResult retrieve_result; CRetrieveResult* retrieve_result = nullptr;
res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>(); auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, auto suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 0); ASSERT_EQ(query_result->ids().int_id().data().size(), 0);
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
retrieve_result = nullptr;
// retrieve pks = {2} // retrieve pks = {2}
{ {
@ -633,11 +671,12 @@ TEST(CApiTest, MultiDeleteGrowingSegment) {
std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, term_expr); std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, term_expr);
res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
suc = query_result->ParseFromArray(retrieve_result.proto_blob, suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 1); ASSERT_EQ(query_result->ids().int_id().data().size(), 1);
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
retrieve_result = nullptr;
// delete pks = {2} // delete pks = {2}
delete_pks = {2}; delete_pks = {2};
@ -658,13 +697,13 @@ TEST(CApiTest, MultiDeleteGrowingSegment) {
// retrieve pks in {2} // retrieve pks in {2}
res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
suc = query_result->ParseFromArray(retrieve_result.proto_blob, suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 0); ASSERT_EQ(query_result->ids().int_id().data().size(), 0);
DeleteRetrievePlan(plan.release()); DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
DeleteCollection(collection); DeleteCollection(collection);
DeleteSegment(segment); DeleteSegment(segment);
@ -721,15 +760,16 @@ TEST(CApiTest, MultiDeleteSealedSegment) {
plan->field_ids_ = target_field_ids; plan->field_ids_ = target_field_ids;
auto max_ts = dataset.timestamps_[N - 1] + 10; auto max_ts = dataset.timestamps_[N - 1] + 10;
CRetrieveResult retrieve_result; CRetrieveResult* retrieve_result = nullptr;
auto res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); auto res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>(); auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, auto suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 0); ASSERT_EQ(query_result->ids().int_id().data().size(), 0);
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
retrieve_result = nullptr;
// retrieve pks = {2} // retrieve pks = {2}
{ {
@ -745,11 +785,12 @@ TEST(CApiTest, MultiDeleteSealedSegment) {
std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, term_expr); std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, term_expr);
res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
suc = query_result->ParseFromArray(retrieve_result.proto_blob, suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 1); ASSERT_EQ(query_result->ids().int_id().data().size(), 1);
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
retrieve_result = nullptr;
// delete pks = {2} // delete pks = {2}
delete_pks = {2}; delete_pks = {2};
@ -770,13 +811,13 @@ TEST(CApiTest, MultiDeleteSealedSegment) {
// retrieve pks in {2} // retrieve pks in {2}
res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result); res = CRetrieve(segment, plan.get(), max_ts, &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
suc = query_result->ParseFromArray(retrieve_result.proto_blob, suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 0); ASSERT_EQ(query_result->ids().int_id().data().size(), 0);
DeleteRetrievePlan(plan.release()); DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
DeleteCollection(collection); DeleteCollection(collection);
DeleteSegment(segment); DeleteSegment(segment);
@ -839,16 +880,17 @@ TEST(CApiTest, DeleteRepeatedPksFromGrowingSegment) {
std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)}; std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)};
plan->field_ids_ = target_field_ids; plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result; CRetrieveResult* retrieve_result = nullptr;
res = CRetrieve( res = CRetrieve(
segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result); segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>(); auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, auto suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 6); ASSERT_EQ(query_result->ids().int_id().data().size(), 6);
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
retrieve_result = nullptr;
// delete data pks = {1, 2, 3} // delete data pks = {1, 2, 3}
std::vector<int64_t> delete_row_ids = {1, 2, 3}; std::vector<int64_t> delete_row_ids = {1, 2, 3};
@ -873,13 +915,14 @@ TEST(CApiTest, DeleteRepeatedPksFromGrowingSegment) {
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
query_result = std::make_unique<proto::segcore::RetrieveResults>(); query_result = std::make_unique<proto::segcore::RetrieveResults>();
suc = query_result->ParseFromArray(retrieve_result.proto_blob, suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 0); ASSERT_EQ(query_result->ids().int_id().data().size(), 0);
DeleteRetrievePlan(plan.release()); DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
retrieve_result = nullptr;
DeleteCollection(collection); DeleteCollection(collection);
DeleteSegment(segment); DeleteSegment(segment);
@ -920,16 +963,17 @@ TEST(CApiTest, DeleteRepeatedPksFromSealedSegment) {
std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)}; std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)};
plan->field_ids_ = target_field_ids; plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result; CRetrieveResult* retrieve_result = nullptr;
auto res = CRetrieve( auto res = CRetrieve(
segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result); segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>(); auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, auto suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 6); ASSERT_EQ(query_result->ids().int_id().data().size(), 6);
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
retrieve_result = nullptr;
// delete data pks = {1, 2, 3} // delete data pks = {1, 2, 3}
std::vector<int64_t> delete_row_ids = {1, 2, 3}; std::vector<int64_t> delete_row_ids = {1, 2, 3};
@ -955,13 +999,13 @@ TEST(CApiTest, DeleteRepeatedPksFromSealedSegment) {
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
query_result = std::make_unique<proto::segcore::RetrieveResults>(); query_result = std::make_unique<proto::segcore::RetrieveResults>();
suc = query_result->ParseFromArray(retrieve_result.proto_blob, suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 0); ASSERT_EQ(query_result->ids().int_id().data().size(), 0);
DeleteRetrievePlan(plan.release()); DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
DeleteCollection(collection); DeleteCollection(collection);
DeleteSegment(segment); DeleteSegment(segment);
@ -1030,16 +1074,17 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnGrowingSegment) {
std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)}; std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)};
plan->field_ids_ = target_field_ids; plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result; CRetrieveResult* retrieve_result = nullptr;
res = CRetrieve( res = CRetrieve(
segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result); segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>(); auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, auto suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 0); ASSERT_EQ(query_result->ids().int_id().data().size(), 0);
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
retrieve_result = nullptr;
// second insert data // second insert data
// insert data with pks = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} , timestamps = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19} // insert data with pks = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} , timestamps = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19}
@ -1061,13 +1106,13 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnGrowingSegment) {
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
query_result = std::make_unique<proto::segcore::RetrieveResults>(); query_result = std::make_unique<proto::segcore::RetrieveResults>();
suc = query_result->ParseFromArray(retrieve_result.proto_blob, suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 3); ASSERT_EQ(query_result->ids().int_id().data().size(), 3);
DeleteRetrievePlan(plan.release()); DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
DeleteCollection(collection); DeleteCollection(collection);
DeleteSegment(segment); DeleteSegment(segment);
@ -1127,18 +1172,19 @@ TEST(CApiTest, InsertSamePkAfterDeleteOnSealedSegment) {
std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)}; std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)};
plan->field_ids_ = target_field_ids; plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result; CRetrieveResult* retrieve_result = nullptr;
auto res = CRetrieve( auto res = CRetrieve(
segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result); segment, plan.get(), dataset.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>(); auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, auto suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->ids().int_id().data().size(), 4); ASSERT_EQ(query_result->ids().int_id().data().size(), 4);
DeleteRetrievePlan(plan.release()); DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
retrieve_result = nullptr;
DeleteCollection(collection); DeleteCollection(collection);
DeleteSegment(segment); DeleteSegment(segment);
@ -1324,13 +1370,21 @@ TEST(CApiTest, RetrieveTestWithExpr) {
std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)}; std::vector<FieldId> target_field_ids{FieldId(100), FieldId(101)};
plan->field_ids_ = target_field_ids; plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result; CRetrieveResult* retrieve_result = nullptr;
auto res = CRetrieve( auto res = CRetrieve(
segment, plan.get(), dataset.timestamps_[0], &retrieve_result); segment, plan.get(), dataset.timestamps_[0], &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
// Test Retrieve by offsets.
int64_t offsets[] = {0, 1, 2};
CRetrieveResult* retrieve_by_offsets_result = nullptr;
res = CRetrieveByOffsets(
segment, plan.get(), offsets, 3, &retrieve_by_offsets_result);
ASSERT_EQ(res.error_code, Success);
DeleteRetrievePlan(plan.release()); DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
DeleteRetrieveResult(retrieve_by_offsets_result);
DeleteCollection(collection); DeleteCollection(collection);
DeleteSegment(segment); DeleteSegment(segment);
} }
@ -4324,13 +4378,13 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) {
i8_fid, i16_fid, i32_fid, i64_fid, float_fid, double_fid}; i8_fid, i16_fid, i32_fid, i64_fid, float_fid, double_fid};
plan->field_ids_ = target_field_ids; plan->field_ids_ = target_field_ids;
CRetrieveResult retrieve_result; CRetrieveResult* retrieve_result = nullptr;
res = CRetrieve( res = CRetrieve(
segment, plan.get(), raw_data.timestamps_[N - 1], &retrieve_result); segment, plan.get(), raw_data.timestamps_[N - 1], &retrieve_result);
ASSERT_EQ(res.error_code, Success); ASSERT_EQ(res.error_code, Success);
auto query_result = std::make_unique<proto::segcore::RetrieveResults>(); auto query_result = std::make_unique<proto::segcore::RetrieveResults>();
auto suc = query_result->ParseFromArray(retrieve_result.proto_blob, auto suc = query_result->ParseFromArray(retrieve_result->proto_blob,
retrieve_result.proto_size); retrieve_result->proto_size);
ASSERT_TRUE(suc); ASSERT_TRUE(suc);
ASSERT_EQ(query_result->fields_data().size(), 6); ASSERT_EQ(query_result->fields_data().size(), 6);
auto fields_data = query_result->fields_data(); auto fields_data = query_result->fields_data();
@ -4369,7 +4423,7 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) {
} }
DeleteRetrievePlan(plan.release()); DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result); DeleteRetrieveResult(retrieve_result);
DeleteSegment(segment); DeleteSegment(segment);
} }

View File

@ -206,5 +206,6 @@ TEST(Futures, Future) {
ASSERT_EQ(r, nullptr); ASSERT_EQ(r, nullptr);
ASSERT_EQ(s.error_code, milvus::FollyCancel); ASSERT_EQ(s.error_code, milvus::FollyCancel);
free((char*)(s.error_msg));
} }
} }

View File

@ -609,10 +609,10 @@ TEST(GroupBY, Reduce) {
CSearchResult c_search_res_1; CSearchResult c_search_res_1;
CSearchResult c_search_res_2; CSearchResult c_search_res_2;
auto status = auto status =
Search({}, c_segment_1, c_plan, c_ph_group, 1L << 63, &c_search_res_1); CSearch(c_segment_1, c_plan, c_ph_group, 1L << 63, &c_search_res_1);
ASSERT_EQ(status.error_code, Success); ASSERT_EQ(status.error_code, Success);
status = status =
Search({}, c_segment_2, c_plan, c_ph_group, 1L << 63, &c_search_res_2); CSearch(c_segment_2, c_plan, c_ph_group, 1L << 63, &c_search_res_2);
ASSERT_EQ(status.error_code, Success); ASSERT_EQ(status.error_code, Success);
std::vector<CSearchResult> results; std::vector<CSearchResult> results;
results.push_back(c_search_res_1); results.push_back(c_search_res_1);

View File

@ -28,6 +28,7 @@
#include "segcore/Reduce.h" #include "segcore/Reduce.h"
#include "segcore/reduce_c.h" #include "segcore/reduce_c.h"
#include "segcore/segment_c.h" #include "segcore/segment_c.h"
#include "futures/Future.h"
#include "DataGen.h" #include "DataGen.h"
#include "PbHelper.h" #include "PbHelper.h"
#include "c_api_test_utils.h" #include "c_api_test_utils.h"
@ -147,8 +148,24 @@ CSearch(CSegmentInterface c_segment,
CPlaceholderGroup c_placeholder_group, CPlaceholderGroup c_placeholder_group,
uint64_t timestamp, uint64_t timestamp,
CSearchResult* result) { CSearchResult* result) {
return Search( auto future =
{}, c_segment, c_plan, c_placeholder_group, timestamp, result); AsyncSearch({}, c_segment, c_plan, c_placeholder_group, timestamp);
auto futurePtr = static_cast<milvus::futures::IFuture*>(
static_cast<void*>(static_cast<CFuture*>(future)));
std::mutex mu;
mu.lock();
futurePtr->registerReadyCallback(
[](CLockedGoMutex* mutex) { ((std::mutex*)(mutex))->unlock(); },
(CLockedGoMutex*)(&mu));
mu.lock();
auto [searchResult, status] = futurePtr->leakyGet();
if (status.error_code != 0) {
return status;
}
*result = static_cast<CSearchResult>(searchResult);
return status;
} }
} // namespace } // namespace

View File

@ -28,6 +28,7 @@ import "C"
import ( import (
"context" "context"
"math"
"unsafe" "unsafe"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -55,14 +56,9 @@ func HandleCStatus(ctx context.Context, status *C.CStatus, extraInfo string, fie
return err return err
} }
// HandleCProto deal with the result proto returned from CGO // UnmarshalCProto unmarshal the proto from C memory
func HandleCProto(cRes *C.CProto, msg proto.Message) error { func UnmarshalCProto(cRes *C.CProto, msg proto.Message) error {
// Standalone CProto is protobuf created by C side, blob := (*(*[math.MaxInt32]byte)(cRes.proto_blob))[:int(cRes.proto_size):int(cRes.proto_size)]
// Passed from c side
// memory is managed manually
lease, blob := cgoconverter.UnsafeGoBytes(&cRes.proto_blob, int(cRes.proto_size))
defer cgoconverter.Release(lease)
return proto.Unmarshal(blob, msg) return proto.Unmarshal(blob, msg)
} }
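
`UnmarshalCProto` above now builds the Go byte slice directly over the C-owned blob instead of going through `cgoconverter.UnsafeGoBytes`. The `(*[math.MaxInt32]byte)` cast is the usual cgo idiom for viewing foreign memory as a slice without copying; a self-contained illustration (the helper name `bytesNoCopy` is mine, not part of this PR):

```go
package main

import (
	"fmt"
	"math"
	"unsafe"
)

// bytesNoCopy reinterprets a raw pointer (for example a blob allocated on the
// C side) as a Go []byte of length and capacity n, without copying. The memory
// remains owned by the C side and must outlive every use of the slice.
func bytesNoCopy(p unsafe.Pointer, n int) []byte {
	// Cast to a pointer to a huge fixed-size array type, then take a full
	// slice expression; the array is never allocated, it only supplies the
	// element type and bounds for the resulting slice header.
	return (*(*[math.MaxInt32]byte)(p))[:n:n]
}

func main() {
	src := []byte("retrieve result blob")
	view := bytesNoCopy(unsafe.Pointer(&src[0]), len(src))
	fmt.Println(string(view)) // same bytes, no copy made
}
```

The slice is only valid while the C allocation is, which is why the retrieve paths below defer `C.DeleteRetrieveResult` so the blob is freed only after unmarshalling.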

View File

@ -163,6 +163,10 @@ func (suite *RetrieveSuite) TestRetrieveSealed() {
suite.NoError(err) suite.NoError(err)
suite.Len(res[0].Result.Offset, 3) suite.Len(res[0].Result.Offset, 3)
suite.manager.Segment.Unpin(segments) suite.manager.Segment.Unpin(segments)
resultByOffsets, err := suite.sealed.RetrieveByOffsets(context.Background(), plan, []int64{0, 1})
suite.NoError(err)
suite.Len(resultByOffsets.Offset, 0)
} }
func (suite *RetrieveSuite) TestRetrieveGrowing() { func (suite *RetrieveSuite) TestRetrieveGrowing() {
@ -182,6 +186,10 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() {
suite.NoError(err) suite.NoError(err)
suite.Len(res[0].Result.Offset, 3) suite.Len(res[0].Result.Offset, 3)
suite.manager.Segment.Unpin(segments) suite.manager.Segment.Unpin(segments)
resultByOffsets, err := suite.growing.RetrieveByOffsets(context.Background(), plan, []int64{0, 1})
suite.NoError(err)
suite.Len(resultByOffsets.Offset, 0)
} }
func (suite *RetrieveSuite) TestRetrieveStreamSealed() { func (suite *RetrieveSuite) TestRetrieveStreamSealed() {

View File

@ -17,8 +17,9 @@
package segments package segments
/* /*
#cgo pkg-config: milvus_segcore #cgo pkg-config: milvus_segcore milvus_futures
#include "futures/future_c.h"
#include "segcore/collection_c.h" #include "segcore/collection_c.h"
#include "segcore/plan_c.h" #include "segcore/plan_c.h"
#include "segcore/reduce_c.h" #include "segcore/reduce_c.h"
@ -53,6 +54,7 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/querynodev2/segments/state" "github.com/milvus-io/milvus/internal/querynodev2/segments/state"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/cgo"
typeutil_internal "github.com/milvus-io/milvus/internal/util/typeutil" typeutil_internal "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
@ -565,34 +567,39 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S
defer s.ptrLock.RUnlock() defer s.ptrLock.RUnlock()
traceCtx := ParseCTraceContext(ctx) traceCtx := ParseCTraceContext(ctx)
defer runtime.KeepAlive(traceCtx)
defer runtime.KeepAlive(searchReq)
hasIndex := s.ExistIndex(searchReq.searchFieldID) hasIndex := s.ExistIndex(searchReq.searchFieldID)
log = log.With(zap.Bool("withIndex", hasIndex)) log = log.With(zap.Bool("withIndex", hasIndex))
log.Debug("search segment...") log.Debug("search segment...")
var searchResult SearchResult
var status C.CStatus
GetSQPool().Submit(func() (any, error) {
tr := timerecord.NewTimeRecorder("cgoSearch") tr := timerecord.NewTimeRecorder("cgoSearch")
status = C.Search(traceCtx.ctx,
future := cgo.Async(
ctx,
func() cgo.CFuturePtr {
return (cgo.CFuturePtr)(C.AsyncSearch(
traceCtx.ctx,
s.ptr, s.ptr,
searchReq.plan.cSearchPlan, searchReq.plan.cSearchPlan,
searchReq.cPlaceholderGroup, searchReq.cPlaceholderGroup,
C.uint64_t(searchReq.mvccTimestamp), C.uint64_t(searchReq.mvccTimestamp),
&searchResult.cSearchResult, ))
},
cgo.WithName("search"),
) )
runtime.KeepAlive(traceCtx) defer future.Release()
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) result, err := future.BlockAndLeakyGet()
return nil, nil if err != nil {
}).Await() log.Warn("Search failed")
if err := HandleCStatus(ctx, &status, "Search failed",
zap.Int64("collectionID", s.Collection()),
zap.Int64("segmentID", s.ID()),
zap.String("segmentType", s.segmentType.String())); err != nil {
return nil, err return nil, err
} }
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("search segment done") log.Debug("search segment done")
return &searchResult, nil return &SearchResult{
cSearchResult: (C.CSearchResult)(result),
}, nil
} }
func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
@ -612,69 +619,65 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
log.Debug("begin to retrieve") log.Debug("begin to retrieve")
traceCtx := ParseCTraceContext(ctx) traceCtx := ParseCTraceContext(ctx)
defer runtime.KeepAlive(traceCtx)
defer runtime.KeepAlive(plan)
maxLimitSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() maxLimitSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
var retrieveResult RetrieveResult
var status C.CStatus
GetSQPool().Submit(func() (any, error) {
ts := C.uint64_t(plan.Timestamp)
tr := timerecord.NewTimeRecorder("cgoRetrieve") tr := timerecord.NewTimeRecorder("cgoRetrieve")
status = C.Retrieve(traceCtx.ctx,
future := cgo.Async(
ctx,
func() cgo.CFuturePtr {
return (cgo.CFuturePtr)(C.AsyncRetrieve(
traceCtx.ctx,
s.ptr, s.ptr,
plan.cRetrievePlan, plan.cRetrievePlan,
ts, C.uint64_t(plan.Timestamp),
&retrieveResult.cRetrieveResult,
C.int64_t(maxLimitSize), C.int64_t(maxLimitSize),
C.bool(plan.ignoreNonPk)) C.bool(plan.ignoreNonPk),
runtime.KeepAlive(traceCtx) ))
},
cgo.WithName("retrieve"),
)
defer future.Release()
result, err := future.BlockAndLeakyGet()
if err != nil {
log.Warn("Retrieve failed")
return nil, err
}
defer C.DeleteRetrieveResult((*C.CRetrieveResult)(result))
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("cgo retrieve done", zap.Duration("timeTaken", tr.ElapseSpan()))
return nil, nil
}).Await()
if err := HandleCStatus(ctx, &status, "Retrieve failed",
zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID()),
zap.Int64("msgID", plan.msgID),
zap.String("segmentType", s.segmentType.String())); err != nil {
return nil, err
}
_, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "partial-segcore-results-deserialization") _, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "partial-segcore-results-deserialization")
defer span.End() defer span.End()
result := new(segcorepb.RetrieveResults) retrieveResult := new(segcorepb.RetrieveResults)
if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil { if err := UnmarshalCProto((*C.CRetrieveResult)(result), retrieveResult); err != nil {
log.Warn("unmarshal retrieve result failed", zap.Error(err))
return nil, err return nil, err
} }
log.Debug("retrieve segment done", log.Debug("retrieve segment done",
zap.Int("resultNum", len(result.Offset)), zap.Int("resultNum", len(retrieveResult.Offset)),
) )
// Sort was done by the segcore. // Sort was done by the segcore.
// sort.Sort(&byPK{result}) // sort.Sort(&byPK{result})
return result, nil return retrieveResult, nil
} }
func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) { func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) {
if len(offsets) == 0 {
return nil, merr.WrapErrParameterInvalid("segment offsets", "empty offsets")
}
if !s.ptrLock.RLockIf(state.IsNotReleased) { if !s.ptrLock.RLockIf(state.IsNotReleased) {
// TODO: check if the segment is readable but not released. too many related logic need to be refactor. // TODO: check if the segment is readable but not released. too many related logic need to be refactor.
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
} }
defer s.ptrLock.RUnlock() defer s.ptrLock.RUnlock()
if s.ptr == nil {
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
}
if len(offsets) == 0 {
return nil, merr.WrapErrParameterInvalid("segment offsets", "empty offsets")
}
fields := []zap.Field{ fields := []zap.Field{
zap.Int64("collectionID", s.Collection()), zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()), zap.Int64("partitionID", s.Partition()),
@ -686,40 +689,49 @@ func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan
log := log.Ctx(ctx).With(fields...) log := log.Ctx(ctx).With(fields...)
log.Debug("begin to retrieve by offsets") log.Debug("begin to retrieve by offsets")
traceCtx := ParseCTraceContext(ctx)
var retrieveResult RetrieveResult
var status C.CStatus
tr := timerecord.NewTimeRecorder("cgoRetrieveByOffsets") tr := timerecord.NewTimeRecorder("cgoRetrieveByOffsets")
status = C.RetrieveByOffsets(traceCtx.ctx, traceCtx := ParseCTraceContext(ctx)
defer runtime.KeepAlive(traceCtx)
defer runtime.KeepAlive(plan)
defer runtime.KeepAlive(offsets)
future := cgo.Async(
ctx,
func() cgo.CFuturePtr {
return (cgo.CFuturePtr)(C.AsyncRetrieveByOffsets(
traceCtx.ctx,
s.ptr, s.ptr,
plan.cRetrievePlan, plan.cRetrievePlan,
&retrieveResult.cRetrieveResult,
(*C.int64_t)(unsafe.Pointer(&offsets[0])), (*C.int64_t)(unsafe.Pointer(&offsets[0])),
C.int64_t(len(offsets))) C.int64_t(len(offsets)),
runtime.KeepAlive(traceCtx) ))
},
cgo.WithName("retrieve-by-offsets"),
)
defer future.Release()
result, err := future.BlockAndLeakyGet()
if err != nil {
log.Warn("RetrieveByOffsets failed")
return nil, err
}
defer C.DeleteRetrieveResult((*C.CRetrieveResult)(result))
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("cgo retrieve by offsets done", zap.Duration("timeTaken", tr.ElapseSpan()))
if err := HandleCStatus(ctx, &status, "RetrieveByOffsets failed", fields...); err != nil {
return nil, err
}
_, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "reduced-segcore-results-deserialization") _, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "reduced-segcore-results-deserialization")
defer span.End() defer span.End()
result := new(segcorepb.RetrieveResults) retrieveResult := new(segcorepb.RetrieveResults)
if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil { if err := UnmarshalCProto((*C.CRetrieveResult)(result), retrieveResult); err != nil {
log.Warn("unmarshal retrieve by offsets result failed", zap.Error(err))
return nil, err return nil, err
} }
log.Debug("retrieve by segment offsets done") log.Debug("retrieve by segment offsets done",
zap.Int("resultNum", len(retrieveResult.Offset)),
return result, nil )
return retrieveResult, nil
} }
func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) (dataPath string, offsetInBinlog int64) { func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) (dataPath string, offsetInBinlog int64) {

View File

@ -0,0 +1,36 @@
package cgo
/*
#cgo pkg-config: milvus_futures
#include "futures/future_c.h"
*/
import "C"
import (
"math"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// initExecutor initialize underlying cgo thread pool.
func initExecutor() {
pt := paramtable.Get()
initPoolSize := int(math.Ceil(pt.QueryNodeCfg.MaxReadConcurrency.GetAsFloat() * pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat()))
C.executor_set_thread_num(C.int(initPoolSize))
resetThreadNum := func(evt *config.Event) {
if evt.HasUpdated {
pt := paramtable.Get()
newSize := int(math.Ceil(pt.QueryNodeCfg.MaxReadConcurrency.GetAsFloat() * pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat()))
log.Info("reset cgo thread num", zap.Int("thread_num", newSize))
C.executor_set_thread_num(C.int(newSize))
}
}
pt.Watch(pt.QueryNodeCfg.MaxReadConcurrency.Key, config.NewHandler("cgo."+pt.QueryNodeCfg.MaxReadConcurrency.Key, resetThreadNum))
pt.Watch(pt.QueryNodeCfg.CGOPoolSizeRatio.Key, config.NewHandler("cgo."+pt.QueryNodeCfg.CGOPoolSizeRatio.Key, resetThreadNum))
}
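
`initExecutor` sizes the segcore futures executor to ceil(MaxReadConcurrency × CGOPoolSizeRatio) and re-applies the formula whenever either key changes, via the two `pt.Watch` handlers; with, say, MaxReadConcurrency = 16 and CGOPoolSizeRatio = 0.5 (illustrative values, not the defaults) the C++ pool would be resized to 8 threads. Note this uses a different base than the cgo call-slot channel further down, which `initCaller` sizes from `hardware.GetCPUNum()` × CGOPoolSizeRatio.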

View File

@ -75,6 +75,8 @@ type (
// Async is a helper function to call a C async function that returns a future. // Async is a helper function to call a C async function that returns a future.
func Async(ctx context.Context, f CGOAsyncFunction, opts ...Opt) Future { func Async(ctx context.Context, f CGOAsyncFunction, opts ...Opt) Future {
initCGO()
options := getDefaultOpt() options := getDefaultOpt()
// apply options. // apply options.
for _, opt := range opts { for _, opt := range opts {

View File

@ -18,7 +18,7 @@ import (
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
paramtable.Init() paramtable.Init()
InitCGO() initCGO()
exitCode := m.Run() exitCode := m.Run()
if exitCode > 0 { if exitCode > 0 {
os.Exit(exitCode) os.Exit(exitCode)
@ -47,10 +47,6 @@ func TestFutureWithSuccessCase(t *testing.T) {
_, err = future.BlockAndLeakyGet() _, err = future.BlockAndLeakyGet()
assert.ErrorIs(t, err, ErrConsumed) assert.ErrorIs(t, err, ErrConsumed)
assert.Eventually(t, func() bool {
return unreleasedCnt.Load() == 0
}, time.Second, time.Millisecond*100)
} }
func TestFutureWithCaseNoInterrupt(t *testing.T) { func TestFutureWithCaseNoInterrupt(t *testing.T) {
@ -186,10 +182,6 @@ func TestFutures(t *testing.T) {
assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.Nil(t, result) assert.Nil(t, result)
runtime.GC() runtime.GC()
assert.Eventually(t, func() bool {
return unreleasedCnt.Load() == 0
}, time.Second, time.Millisecond*100)
} }
func TestConcurrent(t *testing.T) { func TestConcurrent(t *testing.T) {
@ -259,8 +251,14 @@ func TestConcurrent(t *testing.T) {
}) })
defer future.Release() defer future.Release()
result, err := future.BlockAndLeakyGet() result, err := future.BlockAndLeakyGet()
assert.NoError(t, err) if err == nil {
assert.Equal(t, 0, getCInt(result)) assert.Equal(t, 0, getCInt(result))
} else {
// the future may be queued and not started,
// so the underlying task may be throw a cancel exception if it's not started.
assert.ErrorIs(t, err, merr.ErrSegcoreFollyCancel)
assert.True(t, errors.Is(err, context.DeadlineExceeded))
}
freeCInt(result) freeCInt(result)
}() }()
} }
@ -271,8 +269,4 @@ func TestConcurrent(t *testing.T) {
return stat.ActiveCount == 0 return stat.ActiveCount == 0
}, 5*time.Second, 100*time.Millisecond) }, 5*time.Second, 100*time.Millisecond)
runtime.GC() runtime.GC()
assert.Eventually(t, func() bool {
return unreleasedCnt.Load() == 0
}, time.Second, time.Millisecond*100)
} }

View File

@ -16,8 +16,6 @@ import (
"context" "context"
"time" "time"
"unsafe" "unsafe"
"go.uber.org/atomic"
) )
const ( const (
@ -27,8 +25,6 @@ const (
caseNoThrowSegcoreException int = 3 caseNoThrowSegcoreException int = 3
) )
var unreleasedCnt = atomic.NewInt32(0)
type testCase struct { type testCase struct {
interval time.Duration interval time.Duration
loopCnt int loopCnt int
@ -39,12 +35,7 @@ func createFutureWithTestCase(ctx context.Context, testCase testCase) Future {
f := func() CFuturePtr { f := func() CFuturePtr {
return (CFuturePtr)(C.future_create_test_case(C.int(testCase.interval.Milliseconds()), C.int(testCase.loopCnt), C.int(testCase.caseNo))) return (CFuturePtr)(C.future_create_test_case(C.int(testCase.interval.Milliseconds()), C.int(testCase.loopCnt), C.int(testCase.caseNo)))
} }
future := Async(ctx, f, future := Async(ctx, f, WithName("createFutureWithTestCase"))
WithName("createFutureWithTestCase"),
WithReleaser(func() {
unreleasedCnt.Dec()
}))
unreleasedCnt.Inc()
return future return future
} }

View File

@ -1,14 +1,12 @@
package cgo package cgo
import ( import (
"math"
"reflect" "reflect"
"sync" "sync"
"go.uber.org/atomic" "go.uber.org/atomic"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/hardware"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
) )
@ -23,20 +21,12 @@ var (
initOnce sync.Once initOnce sync.Once
) )
// InitCGO initializes the cgo caller and future manager. // initCGO initializes the cgo caller and future manager.
// Please call this function before using any cgo utilities. func initCGO() {
func InitCGO() {
initOnce.Do(func() { initOnce.Do(func() {
nodeID := paramtable.GetStringNodeID() nodeID := paramtable.GetStringNodeID()
chSize := int64(math.Ceil(float64(hardware.GetCPUNum()) * paramtable.Get().QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())) initCaller(nodeID)
if chSize <= 0 { initExecutor()
chSize = 1
}
caller = &cgoCaller{
// TODO: temporary solution, need to find a better way to set the pool size.
ch: make(chan struct{}, chSize),
nodeID: nodeID,
}
futureManager = newActiveFutureManager(nodeID) futureManager = newActiveFutureManager(nodeID)
futureManager.Run() futureManager.Run()
}) })

View File

@ -1,14 +1,28 @@
package cgo package cgo
import ( import (
"math"
"runtime" "runtime"
"time" "time"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/hardware"
"github.com/milvus-io/milvus/pkg/util/paramtable"
) )
var caller *cgoCaller var caller *cgoCaller
func initCaller(nodeID string) {
chSize := int64(math.Ceil(float64(hardware.GetCPUNum()) * paramtable.Get().QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat()))
if chSize <= 0 {
chSize = 1
}
caller = &cgoCaller{
ch: make(chan struct{}, chSize),
nodeID: nodeID,
}
}
// getCGOCaller returns the cgoCaller instance. // getCGOCaller returns the cgoCaller instance.
func getCGOCaller() *cgoCaller { func getCGOCaller() *cgoCaller {
return caller return caller

View File

@ -2,6 +2,7 @@ package metrics
import ( import (
"sync" "sync"
"time"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
@ -10,6 +11,22 @@ var (
subsystemCGO = "cgo" subsystemCGO = "cgo"
cgoLabelName = "name" cgoLabelName = "name"
once sync.Once once sync.Once
bucketsForCGOCall = []float64{
10 * time.Nanosecond.Seconds(),
100 * time.Nanosecond.Seconds(),
250 * time.Nanosecond.Seconds(),
500 * time.Nanosecond.Seconds(),
time.Microsecond.Seconds(),
10 * time.Microsecond.Seconds(),
20 * time.Microsecond.Seconds(),
50 * time.Microsecond.Seconds(),
100 * time.Microsecond.Seconds(),
250 * time.Microsecond.Seconds(),
500 * time.Microsecond.Seconds(),
time.Millisecond.Seconds(),
2 * time.Millisecond.Seconds(),
10 * time.Millisecond.Seconds(),
}
ActiveFutureTotal = prometheus.NewGaugeVec( ActiveFutureTotal = prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
@ -38,6 +55,7 @@ var (
Subsystem: subsystemCGO, Subsystem: subsystemCGO,
Name: "cgo_duration_seconds", Name: "cgo_duration_seconds",
Help: "Histogram of cgo call duration in seconds.", Help: "Histogram of cgo call duration in seconds.",
Buckets: bucketsForCGOCall,
}, []string{ }, []string{
nodeIDLabelName, nodeIDLabelName,
cgoLabelName, cgoLabelName,
@ -50,6 +68,7 @@ var (
Subsystem: subsystemCGO, Subsystem: subsystemCGO,
Name: "cgo_queue_duration_seconds", Name: "cgo_queue_duration_seconds",
Help: "Duration of cgo call in queue.", Help: "Duration of cgo call in queue.",
Buckets: bucketsForCGOCall,
}, []string{ }, []string{
nodeIDLabelName, nodeIDLabelName,
}, },
@ -59,8 +78,9 @@ var (
// RegisterCGOMetrics registers the cgo metrics. // RegisterCGOMetrics registers the cgo metrics.
func RegisterCGOMetrics(registry *prometheus.Registry) { func RegisterCGOMetrics(registry *prometheus.Registry) {
once.Do(func() { once.Do(func() {
prometheus.MustRegister(RunningCgoCallTotal) registry.MustRegister(ActiveFutureTotal)
prometheus.MustRegister(CGODuration) registry.MustRegister(RunningCgoCallTotal)
prometheus.MustRegister(CGOQueueDuration) registry.MustRegister(CGODuration)
registry.MustRegister(CGOQueueDuration)
}) })
} }
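
Two changes in the metrics file above: the cgo duration histograms get an explicit bucket layout running from 10 ns up to 10 ms, and `RegisterCGOMetrics` now registers on the registry it is handed rather than the process-global default, so it composes with registry-scoped setups such as the `TestRegisterMetrics` case below and the querynode registration that now calls it. A hedged usage sketch (only `prometheus.NewRegistry` is upstream API; the helper name `newQueryNodeRegistry` is mine):

```go
// Sketch: wiring the cgo metrics into a dedicated registry.
func newQueryNodeRegistry() *prometheus.Registry {
	reg := prometheus.NewRegistry()
	metrics.RegisterCGOMetrics(reg) // idempotent: guarded by the package-level sync.Once
	metrics.RegisterQueryNode(reg)  // querynode registration now pulls in the cgo metrics too
	return reg
}
```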

View File

@ -38,6 +38,7 @@ func TestRegisterMetrics(t *testing.T) {
RegisterMetaMetrics(r) RegisterMetaMetrics(r)
RegisterStorageMetrics(r) RegisterStorageMetrics(r)
RegisterMsgStreamMetrics(r) RegisterMsgStreamMetrics(r)
RegisterCGOMetrics(r)
}) })
} }

View File

@ -792,6 +792,8 @@ func RegisterQueryNode(registry *prometheus.Registry) {
registry.MustRegister(QueryNodeSegmentPruneRatio) registry.MustRegister(QueryNodeSegmentPruneRatio)
registry.MustRegister(QueryNodeApplyBFCost) registry.MustRegister(QueryNodeApplyBFCost)
registry.MustRegister(QueryNodeForwardDeleteCost) registry.MustRegister(QueryNodeForwardDeleteCost)
// Add cgo metrics
RegisterCGOMetrics(registry)
} }
func CleanupQueryNodeCollectionMetrics(nodeID int64, collectionID int64) { func CleanupQueryNodeCollectionMetrics(nodeID int64, collectionID int64) {