mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
Change API retrieve return type from CProtoResult to CProto (#11555)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
eb41afc661
commit
5fdc6626cb
@ -16,14 +16,6 @@
|
|||||||
|
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
|
|
||||||
inline CProtoResult
|
|
||||||
AllocCProtoResult(const google::protobuf::Message& msg) {
|
|
||||||
auto size = msg.ByteSize();
|
|
||||||
void* buffer = malloc(size);
|
|
||||||
msg.SerializePartialToArray(buffer, size);
|
|
||||||
return CProtoResult{CStatus{Success}, CProto{buffer, size}};
|
|
||||||
}
|
|
||||||
|
|
||||||
inline CStatus
|
inline CStatus
|
||||||
SuccessCStatus() {
|
SuccessCStatus() {
|
||||||
return CStatus{Success, ""};
|
return CStatus{Success, ""};
|
||||||
|
|||||||
@ -53,11 +53,6 @@ typedef struct CLoadDeletedRecordInfo {
|
|||||||
int64_t row_count;
|
int64_t row_count;
|
||||||
} CLoadDeletedRecordInfo;
|
} CLoadDeletedRecordInfo;
|
||||||
|
|
||||||
typedef struct CProtoResult {
|
|
||||||
CStatus status;
|
|
||||||
CProto proto;
|
|
||||||
} CProtoResult;
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -87,6 +87,30 @@ Search(CSegmentInterface c_segment,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
DeleteRetrieveResult(CRetrieveResult* retrieve_result) {
|
||||||
|
std::free((void*)(retrieve_result->proto_blob));
|
||||||
|
}
|
||||||
|
|
||||||
|
CStatus
|
||||||
|
Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp, CRetrieveResult* result) {
|
||||||
|
try {
|
||||||
|
auto segment = (const milvus::segcore::SegmentInterface*)c_segment;
|
||||||
|
auto plan = (const milvus::query::RetrievePlan*)c_plan;
|
||||||
|
auto retrieve_result = segment->Retrieve(plan, timestamp);
|
||||||
|
|
||||||
|
auto size = retrieve_result->ByteSize();
|
||||||
|
void* buffer = malloc(size);
|
||||||
|
retrieve_result->SerializePartialToArray(buffer, size);
|
||||||
|
|
||||||
|
result->proto_blob = buffer;
|
||||||
|
result->proto_size = size;
|
||||||
|
return milvus::SuccessCStatus();
|
||||||
|
} catch (std::exception& e) {
|
||||||
|
return milvus::FailureCStatus(UnexpectedError, e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int64_t
|
int64_t
|
||||||
GetMemoryUsageInBytes(CSegmentInterface c_segment) {
|
GetMemoryUsageInBytes(CSegmentInterface c_segment) {
|
||||||
auto segment = (milvus::segcore::SegmentInterface*)c_segment;
|
auto segment = (milvus::segcore::SegmentInterface*)c_segment;
|
||||||
@ -237,15 +261,3 @@ DropSealedSegmentIndex(CSegmentInterface c_segment, int64_t field_id) {
|
|||||||
return milvus::FailureCStatus(UnexpectedError, e.what());
|
return milvus::FailureCStatus(UnexpectedError, e.what());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CProtoResult
|
|
||||||
Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp) {
|
|
||||||
try {
|
|
||||||
auto segment = (const milvus::segcore::SegmentInterface*)c_segment;
|
|
||||||
auto plan = (const milvus::query::RetrievePlan*)c_plan;
|
|
||||||
auto result = segment->Retrieve(plan, timestamp);
|
|
||||||
return milvus::AllocCProtoResult(*result);
|
|
||||||
} catch (std::exception& e) {
|
|
||||||
return CProtoResult{milvus::FailureCStatus(UnexpectedError, e.what())};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ extern "C" {
|
|||||||
|
|
||||||
typedef void* CSegmentInterface;
|
typedef void* CSegmentInterface;
|
||||||
typedef void* CSearchResult;
|
typedef void* CSearchResult;
|
||||||
typedef void* CRetrieveResult;
|
typedef CProto CRetrieveResult;
|
||||||
|
|
||||||
////////////////////////////// common interfaces //////////////////////////////
|
////////////////////////////// common interfaces //////////////////////////////
|
||||||
CSegmentInterface
|
CSegmentInterface
|
||||||
@ -44,8 +44,11 @@ Search(CSegmentInterface c_segment,
|
|||||||
uint64_t timestamp,
|
uint64_t timestamp,
|
||||||
CSearchResult* result);
|
CSearchResult* result);
|
||||||
|
|
||||||
CProtoResult
|
void
|
||||||
Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp);
|
DeleteRetrieveResult(CRetrieveResult* retrieve_result);
|
||||||
|
|
||||||
|
CStatus
|
||||||
|
Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp, CRetrieveResult* result);
|
||||||
|
|
||||||
int64_t
|
int64_t
|
||||||
GetMemoryUsageInBytes(CSegmentInterface c_segment);
|
GetMemoryUsageInBytes(CSegmentInterface c_segment);
|
||||||
|
|||||||
@ -24,6 +24,7 @@
|
|||||||
#include "index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h"
|
#include "index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h"
|
||||||
#include "pb/milvus.pb.h"
|
#include "pb/milvus.pb.h"
|
||||||
#include "pb/plan.pb.h"
|
#include "pb/plan.pb.h"
|
||||||
|
#include "query/ExprImpl.h"
|
||||||
#include "segcore/Collection.h"
|
#include "segcore/Collection.h"
|
||||||
#include "segcore/reduce_c.h"
|
#include "segcore/reduce_c.h"
|
||||||
#include "test_utils/DataGen.h"
|
#include "test_utils/DataGen.h"
|
||||||
@ -351,6 +352,44 @@ TEST(CApiTest, SearchTestWithExpr) {
|
|||||||
DeleteSegment(segment);
|
DeleteSegment(segment);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CApiTest, RetrieveTestWithExpr) {
|
||||||
|
auto collection = NewCollection(get_default_schema_config());
|
||||||
|
auto segment = NewSegment(collection, 0, Growing);
|
||||||
|
|
||||||
|
int N = 10000;
|
||||||
|
auto [raw_data, timestamps, uids] = generate_data(N);
|
||||||
|
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
|
||||||
|
|
||||||
|
int64_t offset;
|
||||||
|
PreInsert(segment, N, &offset);
|
||||||
|
|
||||||
|
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
|
||||||
|
ASSERT_EQ(ins_res.error_code, Success);
|
||||||
|
|
||||||
|
auto schema = ((milvus::segcore::Collection*)collection)->get_schema();
|
||||||
|
auto plan = std::make_unique<query::RetrievePlan>(*schema);
|
||||||
|
|
||||||
|
// create retrieve plan "age in [0]"
|
||||||
|
auto term_expr = std::make_unique<query::TermExprImpl<int64_t>>();
|
||||||
|
term_expr->field_offset_ = FieldOffset(1);
|
||||||
|
term_expr->data_type_ = DataType::INT32;
|
||||||
|
term_expr->terms_.emplace_back(0);
|
||||||
|
|
||||||
|
plan->plan_node_ = std::make_unique<query::RetrievePlanNode>();
|
||||||
|
plan->plan_node_->predicate_ = std::move(term_expr);
|
||||||
|
std::vector<FieldOffset> target_offsets{FieldOffset(0), FieldOffset(1)};
|
||||||
|
plan->field_offsets_ = target_offsets;
|
||||||
|
|
||||||
|
CRetrieveResult retrieve_result;
|
||||||
|
auto res = Retrieve(segment, plan.release(), timestamps[0], &retrieve_result);
|
||||||
|
ASSERT_EQ(res.error_code, Success);
|
||||||
|
|
||||||
|
DeleteRetrievePlan(plan.release());
|
||||||
|
DeleteRetrieveResult(&retrieve_result);
|
||||||
|
DeleteCollection(collection);
|
||||||
|
DeleteSegment(segment);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CApiTest, GetMemoryUsageInBytesTest) {
|
TEST(CApiTest, GetMemoryUsageInBytesTest) {
|
||||||
auto collection = NewCollection(get_default_schema_config());
|
auto collection = NewCollection(get_default_schema_config());
|
||||||
auto segment = NewSegment(collection, 0, Growing);
|
auto segment = NewSegment(collection, 0, Growing);
|
||||||
|
|||||||
@ -80,17 +80,12 @@ func HandleCStatus(status *C.CStatus, extraInfo string) error {
|
|||||||
return errors.New(finalMsg)
|
return errors.New(finalMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleCProtoResult deal with the result proto returned from CGO
|
// HandleCProto deal with the result proto returned from CGO
|
||||||
func HandleCProtoResult(cRes *C.CProtoResult, msg proto.Message) error {
|
func HandleCProto(cRes *C.CProto, msg proto.Message) error {
|
||||||
// Standalone CProto is protobuf created by C side,
|
// Standalone CProto is protobuf created by C side,
|
||||||
// Passed from c side
|
// Passed from c side
|
||||||
// memory is managed manually
|
// memory is managed manually
|
||||||
err := HandleCStatus(&cRes.status, "")
|
blob := C.GoBytes(unsafe.Pointer(cRes.proto_blob), C.int32_t(cRes.proto_size))
|
||||||
if err != nil {
|
defer C.free(cRes.proto_blob)
|
||||||
return err
|
|
||||||
}
|
|
||||||
cpro := cRes.proto
|
|
||||||
blob := C.GoBytes(unsafe.Pointer(cpro.proto_blob), C.int32_t(cpro.proto_size))
|
|
||||||
defer C.free(cpro.proto_blob)
|
|
||||||
return proto.Unmarshal(blob, msg)
|
return proto.Unmarshal(blob, msg)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,6 +35,11 @@ type MarshaledHits struct {
|
|||||||
cMarshaledHits C.CMarshaledHits
|
cMarshaledHits C.CMarshaledHits
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RetrieveResult contains a pointer to the retrieve result in C++ memory
|
||||||
|
type RetrieveResult struct {
|
||||||
|
cRetrieveResult C.CRetrieveResult
|
||||||
|
}
|
||||||
|
|
||||||
func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult, numSegments int64) error {
|
func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult, numSegments int64) error {
|
||||||
if plan.cSearchPlan == nil {
|
if plan.cSearchPlan == nil {
|
||||||
return errors.New("nil search plan")
|
return errors.New("nil search plan")
|
||||||
|
|||||||
@ -316,10 +316,15 @@ func (s *Segment) retrieve(plan *RetrievePlan) (*segcorepb.RetrieveResults, erro
|
|||||||
if s.segmentPtr == nil {
|
if s.segmentPtr == nil {
|
||||||
return nil, errors.New("null seg core pointer")
|
return nil, errors.New("null seg core pointer")
|
||||||
}
|
}
|
||||||
resProto := C.Retrieve(s.segmentPtr, plan.cRetrievePlan, C.uint64_t(plan.Timestamp))
|
|
||||||
|
var retrieveResult RetrieveResult
|
||||||
|
ts := C.uint64_t(plan.Timestamp)
|
||||||
|
status := C.Retrieve(s.segmentPtr, plan.cRetrievePlan, ts, &retrieveResult.cRetrieveResult)
|
||||||
|
if err := HandleCStatus(&status, "Retrieve failed"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
result := new(segcorepb.RetrieveResults)
|
result := new(segcorepb.RetrieveResults)
|
||||||
err := HandleCProtoResult(&resProto, result)
|
if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user