diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp
index c84bf29697..af13ede75b 100644
--- a/internal/core/src/segcore/reduce_c.cpp
+++ b/internal/core/src/segcore/reduce_c.cpp
@@ -64,6 +64,7 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
     auto num_segments = search_results.size();
     AssertInfo(num_segments > 0, "num segment must greater than 0");
 
+    std::unordered_set<int64_t> pk_set;
     int64_t skip_dup_cnt = 0;
     for (int64_t qi = 0; qi < nq; qi++) {
         std::vector<SearchResultPair> result_pairs;
@@ -86,38 +87,25 @@ GetResultData(std::vector<std::vector<int64_t>>& search_records,
             search_records[index].push_back(result_pair.offset_++);
         }
 #else
-        float prev_dis = MAXFLOAT;
-        std::unordered_set<int64_t> prev_pk_set;
+        pk_set.clear();
         while (curr_offset - base_offset < topk) {
             result_pairs[0].reset_distance();
             std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
             auto& result_pair = result_pairs[0];
             auto index = result_pair.index_;
             int64_t curr_pk = result_pair.search_result_->primary_keys_[result_pair.offset_];
-            float curr_dis = result_pair.search_result_->result_distances_[result_pair.offset_];
             // remove duplicates
-            if (curr_pk == INVALID_ID || std::abs(curr_dis - prev_dis) > 0.00001) {
+            if (curr_pk == INVALID_ID || pk_set.count(curr_pk) == 0) {
                 result_pair.search_result_->result_offsets_.push_back(curr_offset++);
-                search_records[index].push_back(result_pair.offset_);
-                prev_dis = curr_dis;
-                prev_pk_set.clear();
-                prev_pk_set.insert(curr_pk);
-            } else {
-                // To handle this case:
-                //    e1: [100, 0.99]
-                //    e2: [101, 0.99] ==> not duplicated, should keep
-                //    e3: [100, 0.99] ==> duplicated, should remove
-                if (prev_pk_set.count(curr_pk) == 0) {
-                    result_pair.search_result_->result_offsets_.push_back(curr_offset++);
-                    search_records[index].push_back(result_pair.offset_);
-                    // prev_pk_set keeps all primary keys with same distance
-                    prev_pk_set.insert(curr_pk);
-                } else {
-                    // the entity with same distance and same primary key must be duplicated
-                    skip_dup_cnt++;
+                search_records[index].push_back(result_pair.offset_++);
+                if (curr_pk != INVALID_ID) {
+                    pk_set.insert(curr_pk);
                 }
+            } else {
+                // skip entity with same primary key
+                result_pair.offset_++;
+                skip_dup_cnt++;
             }
-            result_pair.offset_++;
         }
 #endif
     }
diff --git a/internal/proxy/task.go b/internal/proxy/task.go
index bb8a2f2179..26090d5e59 100644
--- a/internal/proxy/task.go
+++ b/internal/proxy/task.go
@@ -1812,8 +1812,7 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
 	for i := int64(0); i < nq; i++ {
 		offsets := make([]int64, len(searchResultData))
 
-		var prevIDSet = make(map[int64]struct{})
-		var prevScore float32 = math.MaxFloat32
+		var idSet = make(map[int64]struct{})
 		var j int64
 		for j = 0; j < topk; {
 			sel := selectSearchResultData(searchResultData, offsets, topk, i)
@@ -1830,28 +1829,15 @@ func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq in
 			}
 
 			// remove duplicates
-			if math.Abs(float64(score)-float64(prevScore)) > 0.00001 {
+			if _, ok := idSet[id]; !ok {
 				typeutil.AppendFieldData(ret.Results.FieldsData, searchResultData[sel].FieldsData, idx)
 				ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
 				ret.Results.Scores = append(ret.Results.Scores, score)
-				prevScore = score
-				prevIDSet = map[int64]struct{}{id: {}}
+				idSet[id] = struct{}{}
 				j++
 			} else {
-				// To handle this case:
-				//    e1: [100, 0.99]
-				//    e2: [101, 0.99] ==> not duplicated, should keep
-				//    e3: [100, 0.99] ==> duplicated, should remove
-				if _, ok := prevIDSet[id]; !ok {
-					typeutil.AppendFieldData(ret.Results.FieldsData, searchResultData[sel].FieldsData, idx)
-					ret.Results.Ids.GetIntId().Data = append(ret.Results.Ids.GetIntId().Data, id)
-					ret.Results.Scores = append(ret.Results.Scores, score)
-					prevIDSet[id] = struct{}{}
-					j++
-				} else {
-					// entity with same id and same score must be duplicated
-					skipDupCnt++
-				}
+				// skip entity with same id
+				skipDupCnt++
 			}
 			offsets[sel]++
 		}
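
Note on the behavioral change: both hunks replace the old duplicate test, which only flagged a hit whose primary key AND score (within 1e-5) matched the previous run of results, with a simpler rule that tracks every primary key already emitted for the current query and drops any later hit that repeats one, regardless of score. Below is a minimal, self-contained Go sketch of that rule, for illustration only; hit and dedupByID are hypothetical names, not identifiers from the codebase.

package main

import "fmt"

// hit is a hypothetical stand-in for one (primary key, score) search result.
type hit struct {
	id    int64
	score float32
}

// dedupByID keeps the first occurrence of each primary key, up to topk hits,
// mirroring the idSet/pk_set bookkeeping introduced by the patch.
func dedupByID(hits []hit, topk int) []hit {
	idSet := make(map[int64]struct{})
	out := make([]hit, 0, topk)
	for _, h := range hits {
		if _, ok := idSet[h.id]; ok {
			// skip entity with same id, regardless of score
			continue
		}
		idSet[h.id] = struct{}{}
		out = append(out, h)
		if len(out) == topk {
			break
		}
	}
	return out
}

func main() {
	hits := []hit{{100, 0.99}, {101, 0.99}, {100, 0.98}, {102, 0.97}}
	// The old score-based test would keep {100 0.98} (different score);
	// the new rule drops it because id 100 was already emitted.
	fmt.Println(dedupByID(hits, 3)) // [{100 0.99} {101 0.99} {102 0.97}]
}

The old prevIDSet/prevScore bookkeeping only caught duplicates within a run of equal scores, so the same primary key could still appear twice at different scores; keeping one set per query closes that gap at the cost of O(topk) extra memory per query.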