diff --git a/internal/core/src/common/CGoHelper.h b/internal/core/src/common/CGoHelper.h index a0707f7c0f..9e892b0b33 100644 --- a/internal/core/src/common/CGoHelper.h +++ b/internal/core/src/common/CGoHelper.h @@ -16,14 +16,6 @@ namespace milvus { -inline CProtoResult -AllocCProtoResult(const google::protobuf::Message& msg) { - auto size = msg.ByteSize(); - void* buffer = malloc(size); - msg.SerializePartialToArray(buffer, size); - return CProtoResult{CStatus{Success}, CProto{buffer, size}}; -} - inline CStatus SuccessCStatus() { return CStatus{Success, ""}; diff --git a/internal/core/src/common/type_c.h b/internal/core/src/common/type_c.h index f521332b98..41cf4526d1 100644 --- a/internal/core/src/common/type_c.h +++ b/internal/core/src/common/type_c.h @@ -53,11 +53,6 @@ typedef struct CLoadDeletedRecordInfo { int64_t row_count; } CLoadDeletedRecordInfo; -typedef struct CProtoResult { - CStatus status; - CProto proto; -} CProtoResult; - #ifdef __cplusplus } #endif diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index 4ae6d45eb2..5a77c3d261 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -87,6 +87,30 @@ Search(CSegmentInterface c_segment, } } +void +DeleteRetrieveResult(CRetrieveResult* retrieve_result) { + std::free((void*)(retrieve_result->proto_blob)); +} + +CStatus +Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp, CRetrieveResult* result) { + try { + auto segment = (const milvus::segcore::SegmentInterface*)c_segment; + auto plan = (const milvus::query::RetrievePlan*)c_plan; + auto retrieve_result = segment->Retrieve(plan, timestamp); + + auto size = retrieve_result->ByteSize(); + void* buffer = malloc(size); + retrieve_result->SerializePartialToArray(buffer, size); + + result->proto_blob = buffer; + result->proto_size = size; + return milvus::SuccessCStatus(); + } catch (std::exception& e) { + return milvus::FailureCStatus(UnexpectedError, e.what()); + } +} + int64_t GetMemoryUsageInBytes(CSegmentInterface c_segment) { auto segment = (milvus::segcore::SegmentInterface*)c_segment; @@ -237,15 +261,3 @@ DropSealedSegmentIndex(CSegmentInterface c_segment, int64_t field_id) { return milvus::FailureCStatus(UnexpectedError, e.what()); } } - -CProtoResult -Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp) { - try { - auto segment = (const milvus::segcore::SegmentInterface*)c_segment; - auto plan = (const milvus::query::RetrievePlan*)c_plan; - auto result = segment->Retrieve(plan, timestamp); - return milvus::AllocCProtoResult(*result); - } catch (std::exception& e) { - return CProtoResult{milvus::FailureCStatus(UnexpectedError, e.what())}; - } -} diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index f65da26b21..024595fff0 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -25,7 +25,7 @@ extern "C" { typedef void* CSegmentInterface; typedef void* CSearchResult; -typedef void* CRetrieveResult; +typedef CProto CRetrieveResult; ////////////////////////////// common interfaces ////////////////////////////// CSegmentInterface @@ -44,8 +44,11 @@ Search(CSegmentInterface c_segment, uint64_t timestamp, CSearchResult* result); -CProtoResult -Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp); +void +DeleteRetrieveResult(CRetrieveResult* retrieve_result); + +CStatus +Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp, CRetrieveResult* result); int64_t GetMemoryUsageInBytes(CSegmentInterface c_segment); diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 44eaaca755..df7eb6ef2c 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -24,6 +24,7 @@ #include "index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h" #include "pb/milvus.pb.h" #include "pb/plan.pb.h" +#include "query/ExprImpl.h" #include "segcore/Collection.h" #include "segcore/reduce_c.h" #include "test_utils/DataGen.h" @@ -351,6 +352,44 @@ TEST(CApiTest, SearchTestWithExpr) { DeleteSegment(segment); } +TEST(CApiTest, RetrieveTestWithExpr) { + auto collection = NewCollection(get_default_schema_config()); + auto segment = NewSegment(collection, 0, Growing); + + int N = 10000; + auto [raw_data, timestamps, uids] = generate_data(N); + auto line_sizeof = (sizeof(int) + sizeof(float) * DIM); + + int64_t offset; + PreInsert(segment, N, &offset); + + auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); + ASSERT_EQ(ins_res.error_code, Success); + + auto schema = ((milvus::segcore::Collection*)collection)->get_schema(); + auto plan = std::make_unique(*schema); + + // create retrieve plan "age in [0]" + auto term_expr = std::make_unique>(); + term_expr->field_offset_ = FieldOffset(1); + term_expr->data_type_ = DataType::INT32; + term_expr->terms_.emplace_back(0); + + plan->plan_node_ = std::make_unique(); + plan->plan_node_->predicate_ = std::move(term_expr); + std::vector target_offsets{FieldOffset(0), FieldOffset(1)}; + plan->field_offsets_ = target_offsets; + + CRetrieveResult retrieve_result; + auto res = Retrieve(segment, plan.release(), timestamps[0], &retrieve_result); + ASSERT_EQ(res.error_code, Success); + + DeleteRetrievePlan(plan.release()); + DeleteRetrieveResult(&retrieve_result); + DeleteCollection(collection); + DeleteSegment(segment); +} + TEST(CApiTest, GetMemoryUsageInBytesTest) { auto collection = NewCollection(get_default_schema_config()); auto segment = NewSegment(collection, 0, Growing); diff --git a/internal/querynode/cgo_helper.go b/internal/querynode/cgo_helper.go index e06b2997fc..a034fcf850 100644 --- a/internal/querynode/cgo_helper.go +++ b/internal/querynode/cgo_helper.go @@ -80,17 +80,12 @@ func HandleCStatus(status *C.CStatus, extraInfo string) error { return errors.New(finalMsg) } -// HandleCProtoResult deal with the result proto returned from CGO -func HandleCProtoResult(cRes *C.CProtoResult, msg proto.Message) error { +// HandleCProto deal with the result proto returned from CGO +func HandleCProto(cRes *C.CProto, msg proto.Message) error { // Standalone CProto is protobuf created by C side, // Passed from c side // memory is managed manually - err := HandleCStatus(&cRes.status, "") - if err != nil { - return err - } - cpro := cRes.proto - blob := C.GoBytes(unsafe.Pointer(cpro.proto_blob), C.int32_t(cpro.proto_size)) - defer C.free(cpro.proto_blob) + blob := C.GoBytes(unsafe.Pointer(cRes.proto_blob), C.int32_t(cRes.proto_size)) + defer C.free(cRes.proto_blob) return proto.Unmarshal(blob, msg) } diff --git a/internal/querynode/reduce.go b/internal/querynode/reduce.go index 0ca32c63ad..cef6a655f5 100644 --- a/internal/querynode/reduce.go +++ b/internal/querynode/reduce.go @@ -35,6 +35,11 @@ type MarshaledHits struct { cMarshaledHits C.CMarshaledHits } +// RetrieveResult contains a pointer to the retrieve result in C++ memory +type RetrieveResult struct { + cRetrieveResult C.CRetrieveResult +} + func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult, numSegments int64) error { if plan.cSearchPlan == nil { return errors.New("nil search plan") diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go index 7b35c2546c..23856d7516 100644 --- a/internal/querynode/segment.go +++ b/internal/querynode/segment.go @@ -316,10 +316,15 @@ func (s *Segment) retrieve(plan *RetrievePlan) (*segcorepb.RetrieveResults, erro if s.segmentPtr == nil { return nil, errors.New("null seg core pointer") } - resProto := C.Retrieve(s.segmentPtr, plan.cRetrievePlan, C.uint64_t(plan.Timestamp)) + + var retrieveResult RetrieveResult + ts := C.uint64_t(plan.Timestamp) + status := C.Retrieve(s.segmentPtr, plan.cRetrievePlan, ts, &retrieveResult.cRetrieveResult) + if err := HandleCStatus(&status, "Retrieve failed"); err != nil { + return nil, err + } result := new(segcorepb.RetrieveResults) - err := HandleCProtoResult(&resProto, result) - if err != nil { + if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil { return nil, err } return result, nil