Optimize search performance

Signed-off-by: xige-16 <xi.ge@zilliz.com>
Authored by xige-16, 2021-04-19 19:30:36 +08:00; committed by yefu.chen
parent b79a408491
commit 1165db75f6
4 changed files with 69 additions and 28 deletions

@@ -65,6 +65,9 @@ set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
 set(FETCHCONTENT_QUIET OFF)
 include( ThirdPartyPackages )
 find_package(OpenMP REQUIRED)
+if (OPENMP_FOUND)
+    set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+endif()
 # **************************** Compiler arguments ****************************
 message( STATUS "Building Milvus CPU version" )

@@ -33,7 +33,7 @@ add_library(milvus_segcore SHARED
 )
 target_link_libraries(milvus_segcore
-        tbb milvus_utils pthread knowhere log milvus_proto
+        tbb milvus_utils pthread knowhere log milvus_proto ${OpenMP_CXX_FLAGS}
         dl backtrace
         milvus_common
         milvus_query
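
Taken together, the two build hunks above enable OpenMP for segcore: find_package(OpenMP) defines OPENMP_FOUND and OpenMP_CXX_FLAGS (typically -fopenmp with GCC/Clang), the first hunk appends that flag to every C++ compile line, and linking ${OpenMP_CXX_FLAGS} into milvus_segcore pulls in the OpenMP runtime, since -fopenmp also acts as a linker flag. (On CMake >= 3.9 the same effect is available per-target via the imported OpenMP::OpenMP_CXX target.) A minimal standalone check that the flags are in effect — illustrative only, not part of the patch:

#include <omp.h>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> owner(8);
#pragma omp parallel for
    for (int i = 0; i < 8; i++) {
        owner[i] = omp_get_thread_num();  // record which thread ran iteration i
    }
    std::printf("max threads: %d\n", omp_get_max_threads());
    for (int t : owner) {
        std::printf("%d ", t);  // all zeros means the loop ran single-threaded
    }
    std::printf("\n");
    return 0;
}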

@@ -172,19 +172,20 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
     try {
         auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
         auto topk = GetTopK(c_plan);
-        std::vector<int64_t> num_queries_peer_group;
+        std::vector<int64_t> num_queries_peer_group(num_groups);
         int64_t total_num_queries = 0;
         for (int i = 0; i < num_groups; i++) {
             auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
-            num_queries_peer_group.push_back(num_queries);
+            num_queries_peer_group[i] = num_queries;
             total_num_queries += num_queries;
         }
         std::vector<float> result_distances(total_num_queries * topk);
         std::vector<int64_t> result_ids(total_num_queries * topk);
         std::vector<std::vector<char>> row_datas(total_num_queries * topk);
-        std::vector<char> temp_ids;
-        int64_t count = 0;
+        std::vector<int64_t> counts(num_segments);
         for (int i = 0; i < num_segments; i++) {
             if (is_selected[i] == false) {
                 continue;
@@ -192,30 +193,46 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
             auto search_result = (SearchResult*)c_search_results[i];
             AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
             auto size = search_result->result_offsets_.size();
+#pragma omp parallel for
             for (int j = 0; j < size; j++) {
                 auto loc = search_result->result_offsets_[j];
                 result_distances[loc] = search_result->result_distances_[j];
                 row_datas[loc] = search_result->row_data_[j];
                 memcpy(&result_ids[loc], search_result->row_data_[j].data(), sizeof(int64_t));
             }
-            count += size;
+            counts[i] = size;
         }
-        AssertInfo(count == total_num_queries * topk, "the reduces result's size less than total_num_queries*topk");
-        int64_t fill_hit_offset = 0;
+        int64_t total_count = 0;
+        for (int i = 0; i < num_segments; i++) {
+            total_count += counts[i];
+        }
+        AssertInfo(total_count == total_num_queries * topk,
+                   "the reduces result's size less than total_num_queries*topk");
+        int64_t last_offset = 0;
         for (int i = 0; i < num_groups; i++) {
             MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
-            for (int j = 0; j < num_queries_peer_group[i]; j++) {
-                milvus::proto::milvus::Hits hits;
-                for (int k = 0; k < topk; k++, fill_hit_offset++) {
-                    hits.add_ids(result_ids[fill_hit_offset]);
-                    hits.add_scores(result_distances[fill_hit_offset]);
-                    auto& row_data = row_datas[fill_hit_offset];
-                    hits.add_row_data(row_data.data(), row_data.size());
-                }
-                auto blob = hits.SerializeAsString();
-                hits_peer_group.hits_.push_back(blob);
-                hits_peer_group.blob_length_.push_back(blob.size());
+            hits_peer_group.hits_.resize(num_queries_peer_group[i]);
+            hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
+            std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
+#pragma omp parallel for
+            for (int m = 0; m < num_queries_peer_group[i]; m++) {
+                for (int n = 0; n < topk; n++) {
+                    int64_t result_offset = last_offset + m * topk + n;
+                    hits[m].add_ids(result_ids[result_offset]);
+                    hits[m].add_scores(result_distances[result_offset]);
+                    auto& row_data = row_datas[result_offset];
+                    hits[m].add_row_data(row_data.data(), row_data.size());
+                }
+            }
+            last_offset = last_offset + num_queries_peer_group[i] * topk;
+#pragma omp parallel for
+            for (int j = 0; j < num_queries_peer_group[i]; j++) {
+                auto blob = hits[j].SerializeAsString();
+                hits_peer_group.hits_[j] = blob;
+                hits_peer_group.blob_length_[j] = blob.size();
             }
         }
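
This restructuring is what makes the OpenMP pragmas safe. The old code advanced a single fill_hit_offset cursor across all queries and serialized each Hits message as it went, so every iteration depended on shared mutable state. The rewrite removes that state: output vectors are sized up front and written by index instead of push_back, the running count becomes a per-segment counts array summed after the loop, and each query's slot range is derived arithmetically as last_offset + m * topk + n. Every iteration then writes a disjoint set of slots, and both the message building and the SerializeAsString calls (usually the expensive part) run under #pragma omp parallel for; only last_offset advances sequentially, between groups. A simplified sketch of the pattern, with a hypothetical Hits type standing in for milvus::proto::milvus::Hits:

#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the protobuf Hits message used in the real code.
struct Hits {
    std::vector<int64_t> ids;
    std::vector<float> scores;
    std::string SerializeAsString() const {
        return std::string(reinterpret_cast<const char*>(ids.data()),
                           ids.size() * sizeof(int64_t));  // toy serialization
    }
};

// Build and serialize one group's per-query blobs in parallel. Each (m, n)
// pair maps to a unique result_offset, so no two iterations share a slot.
std::vector<std::string>
BuildGroupBlobs(const std::vector<int64_t>& result_ids,
                const std::vector<float>& result_distances,
                int64_t last_offset, int64_t num_queries, int64_t topk) {
    std::vector<Hits> hits(num_queries);
#pragma omp parallel for
    for (int64_t m = 0; m < num_queries; m++) {
        for (int64_t n = 0; n < topk; n++) {
            int64_t result_offset = last_offset + m * topk + n;
            hits[m].ids.push_back(result_ids[result_offset]);
            hits[m].scores.push_back(result_distances[result_offset]);
        }
    }
    std::vector<std::string> blobs(num_queries);
#pragma omp parallel for
    for (int64_t m = 0; m < num_queries; m++) {
        blobs[m] = hits[m].SerializeAsString();  // serialization parallelizes too
    }
    return blobs;
}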
@@ -245,27 +262,37 @@ ReorganizeSingleQueryResult(CMarshaledHits* c_marshaled_hits,
         auto search_result = (SearchResult*)c_search_result;
         auto topk = GetTopK(c_plan);
         std::vector<int64_t> num_queries_peer_group;
-        int64_t total_num_queries = 0;
         for (int i = 0; i < num_groups; i++) {
             auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
             num_queries_peer_group.push_back(num_queries);
         }
-        int64_t fill_hit_offset = 0;
+        int64_t last_offset = 0;
         for (int i = 0; i < num_groups; i++) {
             MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
-            for (int j = 0; j < num_queries_peer_group[i]; j++) {
-                milvus::proto::milvus::Hits hits;
-                for (int k = 0; k < topk; k++, fill_hit_offset++) {
-                    hits.add_scores(search_result->result_distances_[fill_hit_offset]);
-                    auto& row_data = search_result->row_data_[fill_hit_offset];
-                    hits.add_row_data(row_data.data(), row_data.size());
+            hits_peer_group.hits_.resize(num_queries_peer_group[i]);
+            hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
+            std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
+#pragma omp parallel for
+            for (int m = 0; m < num_queries_peer_group[i]; m++) {
+                for (int n = 0; n < topk; n++) {
+                    int64_t result_offset = last_offset + m * topk + n;
+                    hits[m].add_scores(search_result->result_distances_[result_offset]);
+                    auto& row_data = search_result->row_data_[result_offset];
+                    hits[m].add_row_data(row_data.data(), row_data.size());
                     int64_t result_id;
                     memcpy(&result_id, row_data.data(), sizeof(int64_t));
-                    hits.add_ids(result_id);
+                    hits[m].add_ids(result_id);
                 }
-                auto blob = hits.SerializeAsString();
-                hits_peer_group.hits_.push_back(blob);
-                hits_peer_group.blob_length_.push_back(blob.size());
+            }
+            last_offset = last_offset + num_queries_peer_group[i] * topk;
+#pragma omp parallel for
+            for (int j = 0; j < num_queries_peer_group[i]; j++) {
+                auto blob = hits[j].SerializeAsString();
+                hits_peer_group.hits_[j] = blob;
+                hits_peer_group.blob_length_[j] = blob.size();
             }
         }
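
ReorganizeSingleQueryResult gets the same transformation (the now-unused total_num_queries accumulator is dropped along the way). One detail worth calling out: in this path the id is not kept in a separate array; it lives in the first eight bytes of each row blob and is recovered with memcpy, which reads it safely regardless of the blob's alignment, unlike a reinterpret_cast-and-dereference. A small illustration, using a hypothetical helper name:

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: the first sizeof(int64_t) bytes of a row blob hold
// the entity id; memcpy extracts them without alignment or strict-aliasing
// problems.
int64_t ExtractRowId(const std::vector<char>& row_data) {
    int64_t id = 0;
    std::memcpy(&id, row_data.data(), sizeof(int64_t));
    return id;
}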

@@ -287,6 +287,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
     if err != nil {
         return err
     }
+    queryNum := searchReq.getNumOfQuery()
     searchRequests := make([]*searchRequest, 0)
     searchRequests = append(searchRequests, searchReq)
@@ -315,6 +316,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
         searchPartitionIDs = partitionIDsInQuery
     }
+    sp.LogFields(oplog.String("statistical time", "stats start"), oplog.Object("nq", queryNum), oplog.Object("dsl", dsl))
     for _, partitionID := range searchPartitionIDs {
         segmentIDs, err := s.replica.getSegmentIDs(partitionID)
         if err != nil {
@@ -336,6 +338,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
         }
     }
+    sp.LogFields(oplog.String("statistical time", "segment search end"))
     if len(searchResults) <= 0 {
         for _, group := range searchRequests {
             nq := group.getNumOfQuery()
@@ -378,28 +381,34 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
     if numSegment == 1 {
         inReduced[0] = true
         err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
+        sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
         if err != nil {
             return err
         }
         marshaledHits, err = reorganizeSingleQueryResult(plan, searchRequests, searchResults[0])
+        sp.LogFields(oplog.String("statistical time", "reorganizeSingleQueryResult end"))
         if err != nil {
             return err
         }
     } else {
         err = reduceSearchResults(searchResults, numSegment, inReduced)
+        sp.LogFields(oplog.String("statistical time", "reduceSearchResults end"))
         if err != nil {
             return err
         }
         err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
+        sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
         if err != nil {
             return err
         }
         marshaledHits, err = reorganizeQueryResults(plan, searchRequests, searchResults, numSegment, inReduced)
+        sp.LogFields(oplog.String("statistical time", "reorganizeQueryResults end"))
         if err != nil {
             return err
         }
     }
     hitsBlob, err := marshaledHits.getHitsBlob()
+    sp.LogFields(oplog.String("statistical time", "getHitsBlob end"))
     if err != nil {
         return err
     }
@@ -457,8 +466,10 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
         }
     }
+    sp.LogFields(oplog.String("statistical time", "before free c++ memory"))
     deleteSearchResults(searchResults)
     deleteMarshaledHits(marshaledHits)
+    sp.LogFields(oplog.String("statistical time", "stats done"))
     plan.delete()
     searchReq.delete()
     return nil
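
The Go-side changes add no new logic; they instrument the query node's search path. Assuming sp is the OpenTracing span already opened for this search message, each sp.LogFields call stamps a named checkpoint onto the trace (stats start with nq and dsl attached, segment search end, the reduce/fill/reorganize steps, getHitsBlob end, before free c++ memory, stats done), so the wall time spent in each phase, including the C++ reduction parallelized above, can be read directly off the tracing backend.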