diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt
index e8ae1fc32b..ab95dd6f1b 100644
--- a/internal/core/CMakeLists.txt
+++ b/internal/core/CMakeLists.txt
@@ -65,6 +65,9 @@ set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
 set(FETCHCONTENT_QUIET OFF)
 include( ThirdPartyPackages )
 find_package(OpenMP REQUIRED)
+if (OPENMP_FOUND)
+    set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+endif()
 
 # **************************** Compiler arguments ****************************
 message( STATUS "Building Milvus CPU version" )
diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt
index 155227608f..ebac67a194 100644
--- a/internal/core/src/segcore/CMakeLists.txt
+++ b/internal/core/src/segcore/CMakeLists.txt
@@ -33,7 +33,7 @@ add_library(milvus_segcore SHARED
 )
 target_link_libraries(milvus_segcore
-        tbb milvus_utils pthread knowhere log milvus_proto
+        tbb milvus_utils pthread knowhere log milvus_proto ${OpenMP_CXX_FLAGS}
         dl backtrace
         milvus_common
         milvus_query
 )
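Context, not part of the patch: `find_package(OpenMP REQUIRED)` populates `OpenMP_CXX_FLAGS` (typically `-fopenmp` on GCC/Clang), so appending it to `CMAKE_CXX_FLAGS` is what lets the `#pragma omp` directives in `reduce_c.cpp` below compile into threaded code. Passing the same flag through `target_link_libraries` pulls the OpenMP runtime in at link time; newer CMake would link the imported target `OpenMP::OpenMP_CXX` instead. A minimal sketch for sanity-checking that the flag is wired through; the file name and build line are illustrative:

```cpp
// omp_check.cpp, build with: g++ -fopenmp omp_check.cpp   (illustrative)
// Each worker prints its id; if the flag is missing, the pragma is
// ignored and the program runs on a single thread.
#include <cstdio>
#include <omp.h>

int main() {
#pragma omp parallel
    {
        std::printf("thread %d of %d\n",
                    omp_get_thread_num(), omp_get_num_threads());
    }
    return 0;
}
```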
diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp
index 0fe4d1725a..990c2f1549 100644
--- a/internal/core/src/segcore/reduce_c.cpp
+++ b/internal/core/src/segcore/reduce_c.cpp
@@ -172,19 +172,20 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
     try {
         auto marshaledHits = std::make_unique<MarshaledHits>(num_groups);
         auto topk = GetTopK(c_plan);
-        std::vector<int64_t> num_queries_peer_group;
+        std::vector<int64_t> num_queries_peer_group(num_groups);
         int64_t total_num_queries = 0;
         for (int i = 0; i < num_groups; i++) {
             auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
-            num_queries_peer_group.push_back(num_queries);
+            num_queries_peer_group[i] = num_queries;
             total_num_queries += num_queries;
         }
         std::vector<float> result_distances(total_num_queries * topk);
         std::vector<int64_t> result_ids(total_num_queries * topk);
         std::vector<std::vector<char>> row_datas(total_num_queries * topk);
+        std::vector<int64_t> temp_ids;
 
-        int64_t count = 0;
+        std::vector<int64_t> counts(num_segments);
         for (int i = 0; i < num_segments; i++) {
             if (is_selected[i] == false) {
                 continue;
             }
@@ -192,30 +193,46 @@ ReorganizeQueryResults(CMarshaledHits* c_marshaled_hits,
             auto search_result = (SearchResult*)c_search_results[i];
             AssertInfo(search_result != nullptr, "search result must not equal to nullptr");
             auto size = search_result->result_offsets_.size();
+#pragma omp parallel for
             for (int j = 0; j < size; j++) {
                 auto loc = search_result->result_offsets_[j];
                 result_distances[loc] = search_result->result_distances_[j];
                 row_datas[loc] = search_result->row_data_[j];
                 memcpy(&result_ids[loc], search_result->row_data_[j].data(), sizeof(int64_t));
             }
-            count += size;
+            counts[i] = size;
         }
 
-        AssertInfo(count == total_num_queries * topk, "the reduces result's size less than total_num_queries*topk");
-        int64_t fill_hit_offset = 0;
+        int64_t total_count = 0;
+        for (int i = 0; i < num_segments; i++) {
+            total_count += counts[i];
+        }
+        AssertInfo(total_count == total_num_queries * topk,
+                   "the reduces result's size less than total_num_queries*topk");
+
+        int64_t last_offset = 0;
         for (int i = 0; i < num_groups; i++) {
             MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
-            for (int j = 0; j < num_queries_peer_group[i]; j++) {
-                milvus::proto::milvus::Hits hits;
-                for (int k = 0; k < topk; k++, fill_hit_offset++) {
-                    hits.add_ids(result_ids[fill_hit_offset]);
-                    hits.add_scores(result_distances[fill_hit_offset]);
-                    auto& row_data = row_datas[fill_hit_offset];
-                    hits.add_row_data(row_data.data(), row_data.size());
+            hits_peer_group.hits_.resize(num_queries_peer_group[i]);
+            hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
+            std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
+#pragma omp parallel for
+            for (int m = 0; m < num_queries_peer_group[i]; m++) {
+                for (int n = 0; n < topk; n++) {
+                    int64_t result_offset = last_offset + m * topk + n;
+                    hits[m].add_ids(result_ids[result_offset]);
+                    hits[m].add_scores(result_distances[result_offset]);
+                    auto& row_data = row_datas[result_offset];
+                    hits[m].add_row_data(row_data.data(), row_data.size());
                 }
-                auto blob = hits.SerializeAsString();
-                hits_peer_group.hits_.push_back(blob);
-                hits_peer_group.blob_length_.push_back(blob.size());
+            }
+            last_offset = last_offset + num_queries_peer_group[i] * topk;
+
+#pragma omp parallel for
+            for (int j = 0; j < num_queries_peer_group[i]; j++) {
+                auto blob = hits[j].SerializeAsString();
+                hits_peer_group.hits_[j] = blob;
+                hits_peer_group.blob_length_[j] = blob.size();
             }
         }
@@ -245,27 +262,37 @@ ReorganizeSingleQueryResult(CMarshaledHits* c_marshaled_hits,
     auto search_result = (SearchResult*)c_search_result;
     auto topk = GetTopK(c_plan);
     std::vector<int64_t> num_queries_peer_group;
+    int64_t total_num_queries = 0;
    for (int i = 0; i < num_groups; i++) {
         auto num_queries = GetNumOfQueries(c_placeholder_groups[i]);
         num_queries_peer_group.push_back(num_queries);
     }
 
-    int64_t fill_hit_offset = 0;
+    int64_t last_offset = 0;
     for (int i = 0; i < num_groups; i++) {
         MarshaledHitsPeerGroup& hits_peer_group = (*marshaledHits).marshaled_hits_[i];
-        for (int j = 0; j < num_queries_peer_group[i]; j++) {
-            milvus::proto::milvus::Hits hits;
-            for (int k = 0; k < topk; k++, fill_hit_offset++) {
-                hits.add_scores(search_result->result_distances_[fill_hit_offset]);
-                auto& row_data = search_result->row_data_[fill_hit_offset];
-                hits.add_row_data(row_data.data(), row_data.size());
+        hits_peer_group.hits_.resize(num_queries_peer_group[i]);
+        hits_peer_group.blob_length_.resize(num_queries_peer_group[i]);
+        std::vector<milvus::proto::milvus::Hits> hits(num_queries_peer_group[i]);
+#pragma omp parallel for
+        for (int m = 0; m < num_queries_peer_group[i]; m++) {
+            for (int n = 0; n < topk; n++) {
+                int64_t result_offset = last_offset + m * topk + n;
+                hits[m].add_scores(search_result->result_distances_[result_offset]);
+                auto& row_data = search_result->row_data_[result_offset];
+                hits[m].add_row_data(row_data.data(), row_data.size());
                 int64_t result_id;
                 memcpy(&result_id, row_data.data(), sizeof(int64_t));
-                hits.add_ids(result_id);
+                hits[m].add_ids(result_id);
             }
-            auto blob = hits.SerializeAsString();
-            hits_peer_group.hits_.push_back(blob);
-            hits_peer_group.blob_length_.push_back(blob.size());
+        }
+        last_offset = last_offset + num_queries_peer_group[i] * topk;
+
+#pragma omp parallel for
+        for (int j = 0; j < num_queries_peer_group[i]; j++) {
+            auto blob = hits[j].SerializeAsString();
+            hits_peer_group.hits_[j] = blob;
+            hits_peer_group.blob_length_[j] = blob.size();
         }
     }
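Editor's note on why the `#pragma omp parallel for` loops above are safe: each sequential cursor from the old code is replaced with a precomputed, disjoint destination. The running `count` becomes a per-segment `counts[i]` that is summed after the loop, the `fill_hit_offset` cursor becomes the closed-form index `last_offset + m * topk + n`, and `push_back` becomes assignment into slots reserved up front with `resize`, so no iteration touches state owned by another. A self-contained sketch of the same pattern, with `std::string` standing in for the protobuf `Hits` message so it builds without protobuf:

```cpp
// Sketch of the disjoint-slot pattern from reduce_c.cpp; sizes are
// illustrative. Build with -fopenmp; without it the pragmas are
// ignored and the code runs serially with the same result.
#include <cstdio>
#include <string>
#include <vector>

int main() {
    const int num_queries = 4, topk = 3;
    std::vector<float> distances(num_queries * topk, 0.5f);

    // Pass 1: build one result object per query. Iteration m writes
    // only hits[m] and reads a slice computed from m, so the loop
    // parallelizes without locks.
    std::vector<std::string> hits(num_queries);
#pragma omp parallel for
    for (int m = 0; m < num_queries; m++) {
        for (int n = 0; n < topk; n++) {
            int result_offset = m * topk + n;  // closed-form index, no shared cursor
            hits[m] += std::to_string(distances[result_offset]) + ";";
        }
    }

    // Pass 2: "serialize" into preallocated slots; assignment into
    // blobs[j] replaces push_back, which would race across threads.
    std::vector<std::string> blobs(num_queries);
    std::vector<size_t> blob_length(num_queries);
#pragma omp parallel for
    for (int j = 0; j < num_queries; j++) {
        blobs[j] = hits[j];  // stands in for SerializeAsString()
        blob_length[j] = blobs[j].size();
    }

    for (int j = 0; j < num_queries; j++) {
        std::printf("query %d: %s (%zu bytes)\n", j, blobs[j].c_str(), blob_length[j]);
    }
    return 0;
}
```

Keeping `SerializeAsString()` in its own parallel loop means the serialization work, which is independent per query, is also spread across cores rather than done inside the first loop's critical path.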
diff --git a/internal/querynode/search_collection.go b/internal/querynode/search_collection.go
index 24c770609d..e46d44e92a 100644
--- a/internal/querynode/search_collection.go
+++ b/internal/querynode/search_collection.go
@@ -287,6 +287,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
     if err != nil {
         return err
     }
+    queryNum := searchReq.getNumOfQuery()
     searchRequests := make([]*searchRequest, 0)
     searchRequests = append(searchRequests, searchReq)
 
@@ -315,6 +316,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
         searchPartitionIDs = partitionIDsInQuery
     }
 
+    sp.LogFields(oplog.String("statistical time", "stats start"), oplog.Object("nq", queryNum), oplog.Object("dsl", dsl))
     for _, partitionID := range searchPartitionIDs {
         segmentIDs, err := s.replica.getSegmentIDs(partitionID)
         if err != nil {
@@ -336,6 +338,7 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
         }
     }
 
+    sp.LogFields(oplog.String("statistical time", "segment search end"))
     if len(searchResults) <= 0 {
         for _, group := range searchRequests {
             nq := group.getNumOfQuery()
@@ -378,28 +381,34 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
     if numSegment == 1 {
         inReduced[0] = true
         err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
+        sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
         if err != nil {
             return err
         }
         marshaledHits, err = reorganizeSingleQueryResult(plan, searchRequests, searchResults[0])
+        sp.LogFields(oplog.String("statistical time", "reorganizeSingleQueryResult end"))
         if err != nil {
             return err
         }
     } else {
         err = reduceSearchResults(searchResults, numSegment, inReduced)
+        sp.LogFields(oplog.String("statistical time", "reduceSearchResults end"))
         if err != nil {
             return err
         }
         err = fillTargetEntry(plan, searchResults, matchedSegments, inReduced)
+        sp.LogFields(oplog.String("statistical time", "fillTargetEntry end"))
         if err != nil {
             return err
         }
         marshaledHits, err = reorganizeQueryResults(plan, searchRequests, searchResults, numSegment, inReduced)
+        sp.LogFields(oplog.String("statistical time", "reorganizeQueryResults end"))
         if err != nil {
             return err
         }
     }
     hitsBlob, err := marshaledHits.getHitsBlob()
+    sp.LogFields(oplog.String("statistical time", "getHitsBlob end"))
     if err != nil {
         return err
     }
@@ -457,8 +466,10 @@ func (s *searchCollection) search(searchMsg *msgstream.SearchMsg) error {
         }
     }
 
+    sp.LogFields(oplog.String("statistical time", "before free c++ memory"))
     deleteSearchResults(searchResults)
     deleteMarshaledHits(marshaledHits)
+    sp.LogFields(oplog.String("statistical time", "stats done"))
     plan.delete()
     searchReq.delete()
     return nil
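The `statistical time` fields above bracket each stage of the query-node search path so the tracing span `sp` can attribute latency stage by stage. On the C++ side, a quick way to get comparable numbers is `omp_get_wtime()`; the hypothetical sketch below (all names illustrative, not from the patch) times a parallel loop and also shows OpenMP's `reduction` clause, the built-in alternative to the hand-rolled `counts` vector for accumulating a total across threads:

```cpp
// Hypothetical stage-timing sketch: omp_get_wtime() brackets the
// parallel region much like the "statistical time" LogFields bracket
// stages on the Go side. reduction(+ : sum) gives each thread a
// private partial sum and combines them at the end of the loop.
#include <cstdio>
#include <omp.h>
#include <vector>

int main() {
    std::vector<double> v(1 << 20, 1.0);

    double t0 = omp_get_wtime();
    double sum = 0.0;
#pragma omp parallel for reduction(+ : sum)
    for (long i = 0; i < static_cast<long>(v.size()); i++) {
        sum += v[i];
    }
    double t1 = omp_get_wtime();

    std::printf("parallel sum took %.3f ms (sum = %.0f)\n", (t1 - t0) * 1e3, sum);
    return 0;
}
```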