optimize search reduce logic (#7066)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
parent b3f10ae5cc
commit 6c75301c70
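Note: taken together, the changes below collapse the old ReduceSearchResults + FillTargetEntry + ReorganizeSearchResults/ReorganizeSingleSearchResult sequence into two C API calls. A rough sketch of the new call order, modeled on the updated CApiTest cases further down (the plan, the segment and the two search-result handles are assumed to have been created beforehand):

    // Sketch only; mirrors the updated CApiTest, Reduce test below.
    std::vector<CSearchResult> results{res1, res2};

    // One call now reduces across segments and fills the target entries.
    auto status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
    assert(status.error_code == Success);

    // Marshaling no longer needs placeholder groups or is_selected flags.
    void* marshaled_hits = nullptr;
    status = ReorganizeSearchResults(&marshaled_hits, results.data(), results.size());
    assert(status.error_code == Success);

    // Hits are read back through the renamed Per-group accessors.
    auto blob_size = GetHitsBlobSize(marshaled_hits);
    std::vector<char> blob(blob_size);
    GetHitsBlob(marshaled_hits, blob.data());
    auto num_queries_group = GetNumQueriesPerGroup(marshaled_hits, 0);
    std::vector<int64_t> hit_size_per_query(num_queries_group);
    GetHitSizePerQueries(marshaled_hits, 0, hit_size_per_query.data());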
@@ -75,6 +75,7 @@ struct SearchResult {
  public:
     // TODO(gexi): utilize these field
+    void* segment_;
     std::vector<int64_t> internal_seg_offsets_;
     std::vector<int64_t> result_offsets_;
     std::vector<std::vector<char>> row_data_;
@@ -546,7 +546,7 @@ GetNumOfQueries(const PlaceholderGroup* group) {
     return group->at(0).num_of_queries_;
 }
 
-[[maybe_unused]] std::unique_ptr<RetrievePlan>
+std::unique_ptr<RetrievePlan>
 CreateRetrievePlan(const Schema& schema, proto::segcore::RetrieveRequest&& request) {
     auto plan = std::make_unique<RetrievePlan>();
     plan->ids_ = std::unique_ptr<proto::schema::IDs>(request.release_ids());
@@ -79,6 +79,7 @@ SegmentInternalInterface::Search(const query::Plan* plan,
     check_search(plan);
     query::ExecPlanNodeVisitor visitor(*this, timestamp, placeholder_group);
     auto results = visitor.get_moved_result(*plan->plan_node_);
+    results.segment_ = (void*)this;
     return results;
 }
 
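Note: the SearchResult and SegmentInternalInterface::Search hunks above give each search result a back-pointer to the segment that produced it. The reduce path can then recover the segment and fill the target entries itself, which is what makes the separate FillTargetEntry C API (removed further down) unnecessary. Roughly, inside the new ReduceSearchResultsAndFillData:

    // The segment recorded at search time is recovered with a cast and used directly.
    auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_);
    segment->FillTargetEntry(plan, *search_result);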
@@ -11,10 +11,12 @@
 
 #include <vector>
 #include <exceptions/EasyAssert.h>
-#include "segcore/reduce_c.h"
 
+#include "query/Plan.h"
+#include "segcore/reduce_c.h"
 #include "segcore/Reduce.h"
 #include "segcore/ReduceStructure.h"
+#include "segcore/SegmentInterface.h"
 #include "common/Types.h"
 #include "pb/milvus.pb.h"
 
@@ -26,7 +28,7 @@ MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, fl
     return status.code();
 }
 
-struct MarshaledHitsPeerGroup {
+struct MarshaledHitsPerGroup {
     std::vector<std::string> hits_;
     std::vector<int64_t> blob_length_;
 };
@@ -41,7 +43,7 @@ struct MarshaledHits {
         return marshaled_hits_.size();
     }
 
-    std::vector<MarshaledHitsPeerGroup> marshaled_hits_;
+    std::vector<MarshaledHitsPerGroup> marshaled_hits_;
 };
 
 void
@@ -53,16 +55,16 @@ DeleteMarshaledHits(CMarshaledHits c_marshaled_hits) {
 void
 GetResultData(std::vector<std::vector<int64_t>>& search_records,
               std::vector<SearchResult*>& search_results,
-              int64_t query_offset,
-              bool* is_selected,
+              int64_t query_idx,
               int64_t topk) {
     auto num_segments = search_results.size();
     AssertInfo(num_segments > 0, "num segment must greater than 0");
     std::vector<SearchResultPair> result_pairs;
+    int64_t query_offset = query_idx * topk;
     for (int j = 0; j < num_segments; ++j) {
-        auto distance = search_results[j]->result_distances_[query_offset];
         auto search_result = search_results[j];
         AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
+        auto distance = search_result->result_distances_[query_offset];
         result_pairs.push_back(SearchResultPair(distance, search_result, query_offset, j));
     }
     int64_t loc_offset = query_offset;
@@ -72,24 +74,21 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
         std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
         auto& result_pair = result_pairs[0];
         auto index = result_pair.index_;
-        is_selected[index] = true;
         result_pair.search_result_->result_offsets_.push_back(loc_offset++);
         search_records[index].push_back(result_pair.offset_++);
     }
 }
 
 void
-ResetSearchResult(std::vector<std::vector<int64_t>>& search_records,
-                  std::vector<SearchResult*>& search_results,
-                  bool* is_selected) {
+ResetSearchResult(std::vector<std::vector<int64_t>>& search_records, std::vector<SearchResult*>& search_results) {
     auto num_segments = search_results.size();
     AssertInfo(num_segments > 0, "num segment must greater than 0");
     for (int i = 0; i < num_segments; i++) {
-        if (is_selected[i] == false) {
-            continue;
-        }
         auto search_result = search_results[i];
         AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
+        if (search_result->result_offsets_.size() == 0) {
+            continue;
+        }
 
         std::vector<float> result_distances;
         std::vector<int64_t> internal_seg_offsets;
@@ -108,8 +107,9 @@ ResetSearchResult(std::vector<std::vector<int64_t>>& search_records,
 }
 
 CStatus
-ReduceSearchResults(CSearchResult* c_search_results, int64_t num_segments, bool* is_selected) {
+ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* c_search_results, int64_t num_segments) {
     try {
+        auto plan = (milvus::query::Plan*)c_plan;
         std::vector<SearchResult*> search_results;
         for (int i = 0; i < num_segments; ++i) {
             search_results.push_back((SearchResult*)c_search_results[i]);
@@ -118,12 +118,17 @@ ReduceSearchResults(CSearchResult* c_search_results, int64_t num_segments, bool*
         auto num_queries = search_results[0]->num_queries_;
         std::vector<std::vector<int64_t>> search_records(num_segments);
 
-        int64_t query_offset = 0;
-        for (int j = 0; j < num_queries; ++j) {
-            GetResultData(search_records, search_results, query_offset, is_selected, topk);
-            query_offset += topk;
+        for (int i = 0; i < num_queries; ++i) {
+            GetResultData(search_records, search_results, i, topk);
         }
-        ResetSearchResult(search_records, search_results, is_selected);
+        ResetSearchResult(search_records, search_results);
+
+        for (int i = 0; i < num_segments; ++i) {
+            auto search_result = search_results[i];
+            auto segment = (milvus::segcore::SegmentInterface*)(search_result->segment_);
+            segment->FillTargetEntry(plan, *search_result);
+        }
+
         auto status = CStatus();
         status.error_code = Success;
         status.error_msg = "";
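Note: GetResultData now receives the query index rather than a precomputed offset, and the is_selected bookkeeping is gone; the caller simply loops over query indices and the offset is derived inside the function. A one-line restatement of the equivalence:

    // The old caller's running "query_offset += topk" becomes, inside GetResultData:
    int64_t query_offset = query_idx * topk;  // query_idx in [0, num_queries)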
@@ -137,43 +142,29 @@ ReduceSearchResults(CSearchResult* c_search_results, int64_t num_segments, bool*
 }
 
 CStatus
-ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits,
-                        CPlaceholderGroup* c_placeholder_groups,
-                        int64_t num_groups,
-                        CSearchResult* c_search_results,
-                        bool* is_selected,
-                        int64_t num_segments,
-                        CSearchPlan c_plan) {
+ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments) {
     try {
-        auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
-        auto topk = GetTopK(c_plan);
-        std::vector<int64_t> num_queries_peer_group(num_groups);
-        int64_t total_num_queries = 0;
-        for (int i = 0; i < num_groups; i++) {
-            auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
-            num_queries_peer_group[i] = num_queries;
-            total_num_queries += num_queries;
-        }
+        auto marshaledHits = std::make_unique<MarshaledHits>(1);
+        auto sr = (SearchResult*)c_search_results[0];
+        auto topk = sr->topk_;
+        auto num_queries = sr->num_queries_;
 
-        std::vector<float> result_distances(total_num_queries * topk);
-        std::vector<int64_t> result_ids(total_num_queries * topk);
-        std::vector<std::vector<char>> row_datas(total_num_queries * topk);
-        std::vector<char> temp_ids;
+        std::vector<float> result_distances(num_queries * topk);
+        std::vector<std::vector<char>> row_datas(num_queries * topk);
 
         std::vector<int64_t> counts(num_segments);
         for (int i = 0; i < num_segments; i++) {
-            if (is_selected[i] == false) {
-                continue;
-            }
             auto search_result = (SearchResult*)c_search_results[i];
             AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
             auto size = search_result->result_offsets_.size();
+            if (size == 0) {
+                continue;
+            }
 #pragma omp parallel for
             for (int j = 0; j < size; j++) {
                 auto loc = search_result->result_offsets_[j];
                 result_distances[loc] = search_result->result_distances_[j];
                 row_datas[loc] = search_result->row_data_[j];
-                memcpy(&result_ids[loc], search_result->row_data_[j].data(), sizeof(int64_t));
             }
             counts[i] = size;
         }
@@ -182,100 +173,35 @@ ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits,
         for (int i = 0; i < num_segments; i++) {
             total_count += counts[i];
         }
-        AssertInfo(total_count == total_num_queries * topk,
-                   "the reduces result's size less than total_num_queries*topk");
+        AssertInfo(total_count == num_queries * topk, "the reduces result's size less than total_num_queries*topk");
 
-        int64_t last_offset = 0;
-        for (int i = 0; i < num_groups; i++) {
-            MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
-            hits_peer_group.hits_.resize(num_queries_peer_group[i]);
-            hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
-            std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
+        MarshaledHitsPerGroup& hits_per_group = (*marshaledHits).marshaled_hits_[0];
+        hits_per_group.hits_.resize(num_queries);
+        hits_per_group.blob_length_.resize(num_queries);
+        std::vector<milvus::proto::milvus::Hits> hits(num_queries);
 #pragma omp parallel for
-            for (int m = 0; m < num_queries_peer_group[i]; m++) {
+        for (int m = 0; m < num_queries; m++) {
             for (int n = 0; n < topk; n++) {
-                    int64_t result_offset = last_offset + m * topk + n;
-                    hits[m].add_ids(result_ids[result_offset]);
+                int64_t result_offset = m * topk + n;
                 hits[m].add_scores(result_distances[result_offset]);
                 auto& row_data = row_datas[result_offset];
                 hits[m].add_row_data(row_data.data(), row_data.size());
+                hits[m].add_ids(*(int64_t*)row_data.data());
             }
         }
-            last_offset = last_offset + num_queries_peer_group[i] * topk;
 
 #pragma omp parallel for
-            for (int j = 0; j < num_queries_peer_group[i]; j++) {
+        for (int j = 0; j < num_queries; j++) {
             auto blob = hits[j].SerializeAsString();
-                hits_peer_group.hits_[j] = blob;
-                hits_peer_group.blob_length_[j] = blob.size();
-            }
+            hits_per_group.hits_[j] = blob;
+            hits_per_group.blob_length_[j] = blob.size();
         }
 
         auto status = CStatus();
         status.error_code = Success;
         status.error_msg = "";
-        auto marshled_res = (CMarshaledHits)marshaledHits.release();
-        *c_marshaled_hits = marshled_res;
-        return status;
-    } catch (std::exception& e) {
-        auto status = CStatus();
-        status.error_code = UnexpectedError;
-        status.error_msg = strdup(e.what());
-        *c_marshaled_hits = nullptr;
-        return status;
-    }
-}
-
-CStatus
-ReorganizeSingleSearchResult(CMarshaledHits* c_marshaled_hits,
-                             CPlaceholderGroup* c_placeholder_groups,
-                             int64_t num_groups,
-                             CSearchResult c_search_result,
-                             CSearchPlan c_plan) {
-    try {
-        auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
-        auto search_result = (SearchResult*)c_search_result;
-        auto topk = GetTopK(c_plan);
-        std::vector<int64_t> num_queries_peer_group;
-        int64_t total_num_queries = 0;
-        for (int i = 0; i < num_groups; i++) {
-            auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
-            num_queries_peer_group.push_back(num_queries);
-        }
-
-        int64_t last_offset = 0;
-        for (int i = 0; i < num_groups; i++) {
-            MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
-            hits_peer_group.hits_.resize(num_queries_peer_group[i]);
-            hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
-            std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
-#pragma omp parallel for
-            for (int m = 0; m < num_queries_peer_group[i]; m++) {
-                for (int n = 0; n < topk; n++) {
-                    int64_t result_offset = last_offset + m * topk + n;
-                    hits[m].add_scores(search_result->result_distances_[result_offset]);
-                    auto& row_data = search_result->row_data_[result_offset];
-                    hits[m].add_row_data(row_data.data(), row_data.size());
-                    int64_t result_id;
-                    memcpy(&result_id, row_data.data(), sizeof(int64_t));
-                    hits[m].add_ids(result_id);
-                }
-            }
-            last_offset = last_offset + num_queries_peer_group[i] * topk;
-
-#pragma omp parallel for
-            for (int j = 0; j < num_queries_peer_group[i]; j++) {
-                auto blob = hits[j].SerializeAsString();
-                hits_peer_group.hits_[j] = blob;
-                hits_peer_group.blob_length_[j] = blob.size();
-            }
-        }
-
-        auto status = CStatus();
-        status.error_code = Success;
-        status.error_msg = "";
-        auto marshled_res = (CMarshaledHits)marshaledHits.release();
-        *c_marshaled_hits = marshled_res;
+        auto marshaled_res = (CMarshaledHits)marshaledHits.release();
+        *c_marshaled_hits = marshaled_res;
         return status;
     } catch (std::exception& e) {
         auto status = CStatus();
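Note: ReorganizeSearchResults now marshals a single group and drops the separate result_ids buffer and temp_ids scratch space; the primary key is read straight from the first eight bytes of each row-data blob, which is what the removed memcpy into result_ids encoded anyway. The id extraction, assuming the row blob still starts with the int64 primary key as in the old code:

    // Each row blob begins with the int64 primary key, so Hits can be filled from row_datas alone.
    auto& row_data = row_datas[result_offset];
    hits[m].add_scores(result_distances[result_offset]);
    hits[m].add_row_data(row_data.data(), row_data.size());
    hits[m].add_ids(*(int64_t*)row_data.data());  // replaces the old result_ids[] lookup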
@@ -318,14 +244,14 @@ GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits) {
 }
 
 int64_t
-GetNumQueriesPeerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index) {
+GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index) {
     auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
     auto& hits = marshaled_hits->marshaled_hits_[group_index].hits_;
     return hits.size();
 }
 
 void
-GetHitSizePeerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query) {
+GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query) {
     auto marshaled_hits = (MarshaledHits*)c_marshaled_hits;
     auto& blob_lens = marshaled_hits->marshaled_hits_[group_index].blob_length_;
     for (int i = 0; i < blob_lens.size(); i++) {
@@ -15,6 +15,7 @@ extern "C" {
 
 #include <stdbool.h>
 #include <stdint.h>
+#include "segcore/plan_c.h"
 #include "segcore/segment_c.h"
 #include "common/type_c.h"
 
@@ -27,23 +28,10 @@ int
 MergeInto(int64_t num_queries, int64_t topk, float* distances, int64_t* uids, float* new_distances, int64_t* new_uids);
 
 CStatus
-ReduceSearchResults(CSearchResult* search_results, int64_t num_segments, bool* is_selected);
+ReduceSearchResultsAndFillData(CSearchPlan c_plan, CSearchResult* search_results, int64_t num_segments);
 
 CStatus
-ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits,
-                        CPlaceholderGroup* c_placeholder_groups,
-                        int64_t num_groups,
-                        CSearchResult* c_search_results,
-                        bool* is_selected,
-                        int64_t num_segments,
-                        CSearchPlan c_plan);
-
-CStatus
-ReorganizeSingleSearchResult(CMarshaledHits* c_marshaled_hits,
-                             CPlaceholderGroup* c_placeholder_groups,
-                             int64_t num_groups,
-                             CSearchResult c_search_result,
-                             CSearchPlan c_plan);
+ReorganizeSearchResults(CMarshaledHits* c_marshaled_hits, CSearchResult* c_search_results, int64_t num_segments);
 
 int64_t
 GetHitsBlobSize(CMarshaledHits c_marshaled_hits);
@@ -52,10 +40,10 @@ void
 GetHitsBlob(CMarshaledHits c_marshaled_hits, const void* hits);
 
 int64_t
-GetNumQueriesPeerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index);
+GetNumQueriesPerGroup(CMarshaledHits c_marshaled_hits, int64_t group_index);
 
 void
-GetHitSizePeerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query);
+GetHitSizePerQueries(CMarshaledHits c_marshaled_hits, int64_t group_index, int64_t* hit_size_peer_query);
 
 #ifdef __cplusplus
 }
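Note: after this header change the public reduce C API shrinks to the two entry points sketched above; ReorganizeSingleSearchResult is gone and the placeholder-group and is_selected parameters disappear. Both remaining calls keep the usual CStatus convention, so a caller would check them roughly like this (error handling only sketched; the Go wrapper below frees error_msg with C.free):

    auto status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
    if (status.error_code != Success) {
        // status.error_msg holds a strdup'd message describing the failure.
    }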
@@ -88,20 +88,6 @@ Search(CSegmentInterface c_segment,
     }
 }
 
-CStatus
-FillTargetEntry(CSegmentInterface c_segment, CSearchPlan c_plan, CSearchResult c_result) {
-    auto segment = (milvus::segcore::SegmentInterface*)c_segment;
-    auto plan = (milvus::query::Plan*)c_plan;
-    auto result = (milvus::SearchResult*)c_result;
-
-    try {
-        segment->FillTargetEntry(plan, *result);
-        return milvus::SuccessCStatus();
-    } catch (std::exception& e) {
-        return milvus::FailureCStatus(UnexpectedError, e.what());
-    }
-}
-
 int64_t
 GetMemoryUsageInBytes(CSegmentInterface c_segment) {
     auto segment = (milvus::segcore::SegmentInterface*)c_segment;
@@ -46,9 +46,6 @@ Search(CSegmentInterface c_segment,
 CProtoResult
 GetEntityByIds(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp);
 
-CStatus
-FillTargetEntry(CSegmentInterface c_segment, CSearchPlan c_plan, CSearchResult result);
-
 int64_t
 GetMemoryUsageInBytes(CSegmentInterface c_segment);
 
@@ -33,7 +33,6 @@ namespace chrono = std::chrono;
 
 using namespace milvus;
 using namespace milvus::segcore;
-// using namespace milvus::proto;
 using namespace milvus::knowhere;
 
 namespace {
@@ -203,7 +202,7 @@ TEST(CApiTest, InsertTest) {
 
     int N = 10000;
     auto [raw_data, timestamps, uids] = generate_data(N);
-    auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
+    auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
 
     int64_t offset;
     PreInsert(segment, N, &offset);
@@ -237,7 +236,7 @@ TEST(CApiTest, SearchTest) {
 
     int N = 10000;
     auto [raw_data, timestamps, uids] = generate_data(N);
-    auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
+    auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
 
     int64_t offset;
     PreInsert(segment, N, &offset);
@@ -294,7 +293,7 @@ TEST(CApiTest, SearchTestWithExpr) {
 
     int N = 10000;
     auto [raw_data, timestamps, uids] = generate_data(N);
-    auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
+    auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
 
     int64_t offset;
     PreInsert(segment, N, &offset);
@@ -350,7 +349,7 @@ TEST(CApiTest, GetMemoryUsageInBytesTest) {
 
     int N = 10000;
    auto [raw_data, timestamps, uids] = generate_data(N);
-    auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
+    auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
 
     int64_t offset;
     PreInsert(segment, N, &offset);
@@ -392,7 +391,7 @@ TEST(CApiTest, GetRowCountTest) {
 
     int N = 10000;
     auto [raw_data, timestamps, uids] = generate_data(N);
-    auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
+    auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
 
     int64_t offset;
     PreInsert(segment, N, &offset);
@@ -454,7 +453,7 @@ TEST(CApiTest, Reduce) {
 
     int N = 10000;
     auto [raw_data, timestamps, uids] = generate_data(N);
-    auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
+    auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
 
     int64_t offset;
     PreInsert(segment, N, &offset);
@@ -503,14 +502,10 @@ TEST(CApiTest, Reduce) {
     results.push_back(res1);
     results.push_back(res2);
 
-    bool is_selected[2] = {false, false};
-    status = ReduceSearchResults(results.data(), 2, is_selected);
+    status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
     assert(status.error_code == Success);
-    FillTargetEntry(segment, plan, res1);
-    FillTargetEntry(segment, plan, res2);
     void* reorganize_search_result = nullptr;
-    status = ReorganizeSearchResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(),
-                                     is_selected, 2, plan);
+    status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size());
     assert(status.error_code == Success);
     auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
     assert(hits_blob_size > 0);
@@ -518,12 +513,12 @@ TEST(CApiTest, Reduce) {
     hits_blob.resize(hits_blob_size);
     GetHitsBlob(reorganize_search_result, hits_blob.data());
     assert(hits_blob.data() != nullptr);
-    auto num_queries_group = GetNumQueriesPeerGroup(reorganize_search_result, 0);
-    assert(num_queries_group == 10);
-    std::vector<int64_t> hit_size_peer_query;
-    hit_size_peer_query.resize(num_queries_group);
-    GetHitSizePeerQueries(reorganize_search_result, 0, hit_size_peer_query.data());
-    assert(hit_size_peer_query[0] > 0);
+    auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0);
+    assert(num_queries_group == num_queries);
+    std::vector<int64_t> hit_size_per_query;
+    hit_size_per_query.resize(num_queries_group);
+    GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data());
+    assert(hit_size_per_query[0] > 0);
 
     DeleteSearchPlan(plan);
     DeletePlaceholderGroup(placeholderGroup);
@@ -540,7 +535,7 @@ TEST(CApiTest, ReduceSearchWithExpr) {
 
     int N = 10000;
     auto [raw_data, timestamps, uids] = generate_data(N);
-    auto line_sizeof = (sizeof(int) + sizeof(float) * 16);
+    auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
 
     int64_t offset;
     PreInsert(segment, N, &offset);
@@ -584,14 +579,10 @@ TEST(CApiTest, ReduceSearchWithExpr) {
     results.push_back(res1);
     results.push_back(res2);
 
-    bool is_selected[2] = {false, false};
-    status = ReduceSearchResults(results.data(), 2, is_selected);
+    status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
     assert(status.error_code == Success);
-    FillTargetEntry(segment, plan, res1);
-    FillTargetEntry(segment, plan, res2);
     void* reorganize_search_result = nullptr;
-    status = ReorganizeSearchResults(&reorganize_search_result, placeholderGroups.data(), 1, results.data(),
-                                     is_selected, 2, plan);
+    status = ReorganizeSearchResults(&reorganize_search_result, results.data(), results.size());
     assert(status.error_code == Success);
     auto hits_blob_size = GetHitsBlobSize(reorganize_search_result);
     assert(hits_blob_size > 0);
@@ -599,12 +590,12 @@ TEST(CApiTest, ReduceSearchWithExpr) {
     hits_blob.resize(hits_blob_size);
    GetHitsBlob(reorganize_search_result, hits_blob.data());
     assert(hits_blob.data() != nullptr);
-    auto num_queries_group = GetNumQueriesPeerGroup(reorganize_search_result, 0);
-    assert(num_queries_group == 10);
-    std::vector<int64_t> hit_size_peer_query;
-    hit_size_peer_query.resize(num_queries_group);
-    GetHitSizePeerQueries(reorganize_search_result, 0, hit_size_peer_query.data());
-    assert(hit_size_peer_query[0] > 0);
+    auto num_queries_group = GetNumQueriesPerGroup(reorganize_search_result, 0);
+    assert(num_queries_group == num_queries);
+    std::vector<int64_t> hit_size_per_query;
+    hit_size_per_query.resize(num_queries_group);
+    GetHitSizePerQueries(reorganize_search_result, 0, hit_size_per_query.data());
+    assert(hit_size_per_query[0] > 0);
 
     DeleteSearchPlan(plan);
     DeletePlaceholderGroup(placeholderGroup);
@@ -1921,10 +1912,8 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) {
 
     std::vector<CSearchResult> results;
     results.push_back(c_search_result_on_bigIndex);
-    bool is_selected[1] = {false};
-    status = ReduceSearchResults(results.data(), 1, is_selected);
+    status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
     assert(status.error_code == Success);
-    FillTargetEntry(segment, plan, c_search_result_on_bigIndex);
 
     auto search_result_on_bigIndex = (*(SearchResult*)c_search_result_on_bigIndex);
     for (int i = 0; i < num_queries; ++i) {
@@ -2073,10 +2062,8 @@ vector_anns: <
 
     std::vector<CSearchResult> results;
     results.push_back(c_search_result_on_bigIndex);
-    bool is_selected[1] = {false};
-    status = ReduceSearchResults(results.data(), 1, is_selected);
+    status = ReduceSearchResultsAndFillData(plan, results.data(), results.size());
     assert(status.error_code == Success);
-    FillTargetEntry(segment, plan, c_search_result_on_bigIndex);
 
     auto search_result_on_bigIndex = (*(SearchResult*)c_search_result_on_bigIndex);
     for (int i = 0; i < num_queries; ++i) {
@@ -843,7 +843,6 @@ func (q *queryCollection) search(msg queryMsg) error {
     }
 
     searchResults := make([]*SearchResult, 0)
-    matchedSegments := make([]*Segment, 0)
     sealedSegmentSearched := make([]UniqueID, 0)
 
     // historical search
@@ -853,7 +852,6 @@ func (q *queryCollection) search(msg queryMsg) error {
         return err1
     }
     searchResults = append(searchResults, hisSearchResults...)
-    matchedSegments = append(matchedSegments, hisSegmentResults...)
     for _, seg := range hisSegmentResults {
         sealedSegmentSearched = append(sealedSegmentSearched, seg.segmentID)
     }
@@ -863,14 +861,12 @@ func (q *queryCollection) search(msg queryMsg) error {
     var err2 error
     for _, channel := range collection.getVChannels() {
         var strSearchResults []*SearchResult
-        var strSegmentResults []*Segment
-        strSearchResults, strSegmentResults, err2 = q.streaming.search(searchRequests, collectionID, searchMsg.PartitionIDs, channel, plan, travelTimestamp)
+        strSearchResults, err2 = q.streaming.search(searchRequests, collectionID, searchMsg.PartitionIDs, channel, plan, travelTimestamp)
         if err2 != nil {
             log.Warn(err2.Error())
             return err2
         }
         searchResults = append(searchResults, strSearchResults...)
-        matchedSegments = append(matchedSegments, strSegmentResults...)
     }
     tr.Record("streaming search done")
 
@@ -939,38 +935,19 @@ func (q *queryCollection) search(msg queryMsg) error {
         }
     }
 
-    inReduced := make([]bool, len(searchResults))
     numSegment := int64(len(searchResults))
     var marshaledHits *MarshaledHits = nil
-    if numSegment == 1 {
-        inReduced[0] = true
-        err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
-        sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
-        if err != nil {
-            return err
-        }
-        marshaledHits, err = reorganizeSingleSearchResult(plan, searchRequests, searchResults[0])
-        sp.LogFields(oplog.String("statistical time", "reorganizeSingleSearchResult end"))
-        if err != nil {
-            return err
-        }
-    } else {
-        err = reduceSearchResults(searchResults, numSegment, inReduced)
-        sp.LogFields(oplog.String("statistical time", "reduceSearchResults end"))
-        if err != nil {
-            return err
-        }
-        err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
-        sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
-        if err != nil {
-            return err
-        }
-        marshaledHits, err = reorganizeSearchResults(plan, searchRequests, searchResults, numSegment, inReduced)
-        sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end"))
-        if err != nil {
-            return err
-        }
-    }
+    err = reduceSearchResultsAndFillData(plan, searchResults, numSegment)
+    sp.LogFields(oplog.String("statistical time", "reduceSearchResults end"))
+    if err != nil {
+        return err
+    }
+    marshaledHits, err = reorganizeSearchResults(searchResults, numSegment)
+    sp.LogFields(oplog.String("statistical time", "reorganizeSearchResults end"))
+    if err != nil {
+        return err
+    }
     hitsBlob, err := marshaledHits.getHitsBlob()
     sp.LogFields(oplog.String("statistical time", "getHitsBlob end"))
     if err != nil {
@@ -23,10 +23,7 @@ import "C"
 import (
     "errors"
     "strconv"
-    "sync"
     "unsafe"
-
-    "github.com/milvus-io/milvus/internal/log"
 )
 
 type SearchResult struct {
@@ -37,17 +34,15 @@ type MarshaledHits struct {
     cMarshaledHits C.CMarshaledHits
 }
 
-func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inReduced []bool) error {
+func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult, numSegments int64) error {
     cSearchResults := make([]C.CSearchResult, 0)
     for _, res := range searchResults {
         cSearchResults = append(cSearchResults, res.cSearchResult)
     }
     cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0])
     cNumSegments := C.long(numSegments)
-    cInReduced := (*C.bool)(&inReduced[0])
 
-    status := C.ReduceSearchResults(cSearchResultPtr, cNumSegments, cInReduced)
-
+    status := C.ReduceSearchResultsAndFillData(plan.cSearchPlan, cSearchResultPtr, cNumSegments)
     errorCode := status.error_code
 
     if errorCode != 0 {
@@ -58,33 +53,7 @@ func reduceSearchResults(searchResults []*SearchResult, numSegments int64, inRed
     return nil
 }
 
-func fillTargetEntry(plan *SearchPlan, searchResults []*SearchResult, matchedSegments []*Segment, inReduced []bool) error {
-    wg := &sync.WaitGroup{}
-    //fmt.Println(inReduced)
-    for i := range inReduced {
-        if inReduced[i] {
-            wg.Add(1)
-            go func(i int) {
-                err := matchedSegments[i].fillTargetEntry(plan, searchResults[i])
-                if err != nil {
-                    log.Warn(err.Error())
-                }
-                wg.Done()
-            }(i)
-        }
-    }
-    wg.Wait()
-    return nil
-}
-
-func reorganizeSearchResults(plan *SearchPlan, searchRequests []*searchRequest, searchResults []*SearchResult, numSegments int64, inReduced []bool) (*MarshaledHits, error) {
-    cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
-    for _, pg := range searchRequests {
-        cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
-    }
-    var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
-    var cNumGroup = (C.long)(len(searchRequests))
-
+func reorganizeSearchResults(searchResults []*SearchResult, numSegments int64) (*MarshaledHits, error) {
     cSearchResults := make([]C.CSearchResult, 0)
     for _, res := range searchResults {
         cSearchResults = append(cSearchResults, res.cSearchResult)
@@ -92,32 +61,9 @@ func reorganizeSearchResults(plan *SearchPlan, searchRequests []*searchRequest,
     cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0])
 
     var cNumSegments = C.long(numSegments)
-    var cInReduced = (*C.bool)(&inReduced[0])
     var cMarshaledHits C.CMarshaledHits
 
-    status := C.ReorganizeSearchResults(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResultPtr, cInReduced, cNumSegments, plan.cSearchPlan)
-    errorCode := status.error_code
-
-    if errorCode != 0 {
-        errorMsg := C.GoString(status.error_msg)
-        defer C.free(unsafe.Pointer(status.error_msg))
-        return nil, errors.New("reorganizeSearchResults failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
-    }
-    return &MarshaledHits{cMarshaledHits: cMarshaledHits}, nil
-}
-
-func reorganizeSingleSearchResult(plan *SearchPlan, placeholderGroups []*searchRequest, searchResult *SearchResult) (*MarshaledHits, error) {
-    cPlaceholderGroups := make([]C.CPlaceholderGroup, 0)
-    for _, pg := range placeholderGroups {
-        cPlaceholderGroups = append(cPlaceholderGroups, (*pg).cPlaceholderGroup)
-    }
-    var cPlaceHolderGroupPtr = (*C.CPlaceholderGroup)(&cPlaceholderGroups[0])
-    var cNumGroup = (C.long)(len(placeholderGroups))
-
-    cSearchResult := searchResult.cSearchResult
-    var cMarshaledHits C.CMarshaledHits
-
-    status := C.ReorganizeSingleSearchResult(&cMarshaledHits, cPlaceHolderGroupPtr, cNumGroup, cSearchResult, plan.cSearchPlan)
+    status := C.ReorganizeSearchResults(&cMarshaledHits, cSearchResultPtr, cNumSegments)
     errorCode := status.error_code
 
     if errorCode != 0 {
@@ -143,10 +89,10 @@ func (mh *MarshaledHits) getHitsBlob() ([]byte, error) {
 
 func (mh *MarshaledHits) hitBlobSizeInGroup(groupOffset int64) ([]int64, error) {
     cGroupOffset := (C.long)(groupOffset)
-    numQueries := C.GetNumQueriesPeerGroup(mh.cMarshaledHits, cGroupOffset)
+    numQueries := C.GetNumQueriesPerGroup(mh.cMarshaledHits, cGroupOffset)
     result := make([]int64, int64(numQueries))
     cResult := (*C.long)(&result[0])
-    C.GetHitSizePeerQueries(mh.cMarshaledHits, cGroupOffset, cResult)
+    C.GetHitSizePerQueries(mh.cMarshaledHits, cGroupOffset, cResult)
     return result, nil
 }
 
|||||||
@ -71,19 +71,14 @@ func TestReduce_AllFunc(t *testing.T) {
|
|||||||
placeholderGroups = append(placeholderGroups, holder)
|
placeholderGroups = append(placeholderGroups, holder)
|
||||||
|
|
||||||
searchResults := make([]*SearchResult, 0)
|
searchResults := make([]*SearchResult, 0)
|
||||||
matchedSegment := make([]*Segment, 0)
|
|
||||||
searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{0})
|
searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{0})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
searchResults = append(searchResults, searchResult)
|
searchResults = append(searchResults, searchResult)
|
||||||
matchedSegment = append(matchedSegment, segment)
|
|
||||||
|
|
||||||
testReduce := make([]bool, len(searchResults))
|
err = reduceSearchResultsAndFillData(plan, searchResults, 1)
|
||||||
err = reduceSearchResults(searchResults, 1, testReduce)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
err = fillTargetEntry(plan, searchResults, matchedSegment, testReduce)
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
marshaledHits, err := reorganizeSearchResults(plan, placeholderGroups, searchResults, 1, testReduce)
|
marshaledHits, err := reorganizeSearchResults(searchResults, 1)
|
||||||
assert.NotNil(t, marshaledHits)
|
assert.NotNil(t, marshaledHits)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
|||||||
@@ -316,26 +316,6 @@ func (s *Segment) getEntityByIds(plan *RetrievePlan) (*segcorepb.RetrieveResults
     return result, nil
 }
 
-func (s *Segment) fillTargetEntry(plan *SearchPlan, result *SearchResult) error {
-    s.segPtrMu.RLock()
-    defer s.segPtrMu.RUnlock()
-    if s.segmentPtr == nil {
-        return errors.New("null seg core pointer")
-    }
-
-    log.Debug("segment fill target entry, ", zap.Int64("segment ID = ", s.segmentID))
-    var status = C.FillTargetEntry(s.segmentPtr, plan.cSearchPlan, result.cSearchResult)
-    errorCode := status.error_code
-
-    if errorCode != 0 {
-        errorMsg := C.GoString(status.error_msg)
-        defer C.free(unsafe.Pointer(status.error_msg))
-        return errors.New("FillTargetEntry failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg)
-    }
-
-    return nil
-}
-
 //-------------------------------------------------------------------------------------- index info interface
 func (s *Segment) setIndexName(fieldID int64, name string) error {
     s.paramMutex.Lock()
@@ -435,22 +435,15 @@ func TestSegment_segmentSearch(t *testing.T) {
     placeholderGroups = append(placeholderGroups, holder)
 
     searchResults := make([]*SearchResult, 0)
-    matchedSegments := make([]*Segment, 0)
 
     searchResult, err := segment.search(plan, placeholderGroups, []Timestamp{travelTimestamp})
     assert.Nil(t, err)
 
     searchResults = append(searchResults, searchResult)
-    matchedSegments = append(matchedSegments, segment)
 
     ///////////////////////////////////
-    inReduced := make([]bool, len(searchResults))
     numSegment := int64(len(searchResults))
-    err2 := reduceSearchResults(searchResults, numSegment, inReduced)
-    assert.NoError(t, err2)
-    err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
+    err = reduceSearchResultsAndFillData(plan, searchResults, numSegment)
     assert.NoError(t, err)
-    marshaledHits, err := reorganizeSearchResults(plan, placeholderGroups, searchResults, numSegment, inReduced)
+    marshaledHits, err := reorganizeSearchResults(searchResults, numSegment)
     assert.NoError(t, err)
     hitsBlob, err := marshaledHits.getHitsBlob()
     assert.NoError(t, err)
@@ -61,15 +61,10 @@ func (s *streaming) close() {
     s.replica.freeAll()
 }
 
-func (s *streaming) search(searchReqs []*searchRequest,
-    collID UniqueID,
-    partIDs []UniqueID,
-    vChannel Channel,
-    plan *SearchPlan,
-    searchTs Timestamp) ([]*SearchResult, []*Segment, error) {
+func (s *streaming) search(searchReqs []*searchRequest, collID UniqueID, partIDs []UniqueID, vChannel Channel,
+    plan *SearchPlan, searchTs Timestamp) ([]*SearchResult, error) {
 
     searchResults := make([]*SearchResult, 0)
-    segmentResults := make([]*Segment, 0)
 
     // get streaming partition ids
     var searchPartIDs []UniqueID
@@ -77,10 +72,10 @@ func (s *streaming) search(searchReqs []*searchRequest,
         strPartIDs, err := s.replica.getPartitionIDs(collID)
         if len(strPartIDs) == 0 {
             // no partitions in collection, do empty search
-            return nil, nil, nil
+            return nil, nil
         }
         if err != nil {
-            return searchResults, segmentResults, err
+            return searchResults, err
         }
         log.Debug("no partition specified, search all partitions",
             zap.Any("collectionID", collID),
@@ -104,22 +99,20 @@ func (s *streaming) search(searchReqs []*searchRequest,
 
     col, err := s.replica.getCollectionByID(collID)
     if err != nil {
-        return nil, nil, err
+        return nil, err
     }
 
     // all partitions have been released
     if len(searchPartIDs) == 0 && col.getLoadType() == loadTypePartition {
-        return nil, nil, errors.New("partitions have been released , collectionID = " +
-            fmt.Sprintln(collID) +
-            "target partitionIDs = " +
-            fmt.Sprintln(partIDs))
+        err = errors.New("partitions have been released , collectionID = " + fmt.Sprintln(collID) + "target partitionIDs = " + fmt.Sprintln(partIDs))
+        return nil, err
     }
 
     if len(searchPartIDs) == 0 && col.getLoadType() == loadTypeCollection {
         if err = col.checkReleasedPartitions(partIDs); err != nil {
-            return nil, nil, err
+            return nil, err
         }
-        return nil, nil, nil
+        return nil, nil
     }
 
     log.Debug("doing search in streaming",
@@ -144,13 +137,13 @@ func (s *streaming) search(searchReqs []*searchRequest,
         )
         if err != nil {
             log.Warn(err.Error())
-            return searchResults, segmentResults, err
+            return searchResults, err
         }
         for _, segID := range segIDs {
            seg, err := s.replica.getSegmentByID(segID)
            if err != nil {
                log.Warn(err.Error())
-                return searchResults, segmentResults, err
+                return searchResults, err
            }
 
            // TSafe less than searchTs means this vChannel is not available
@@ -175,12 +168,11 @@ func (s *streaming) search(searchReqs []*searchRequest,
 
            searchResult, err := seg.search(plan, searchReqs, []Timestamp{searchTs})
            if err != nil {
-                return searchResults, segmentResults, err
+                return searchResults, err
            }
            searchResults = append(searchResults, searchResult)
-            segmentResults = append(segmentResults, seg)
        }
    }
 
-    return searchResults, segmentResults, nil
+    return searchResults, nil
 }