Fix reduce decreasing recall (#21981)

Signed-off-by: yah01 <yang.cen@zilliz.com>
This commit is contained in:
yah01 2023-02-06 11:23:53 +08:00 committed by GitHub
parent e25f987a5c
commit 73ce87dfe5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 12 deletions

View File

@ -9,15 +9,17 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <cstdint>
#include <vector>
#include <algorithm>
#include "Reduce.h"
#include <log/Log.h>
#include "Reduce.h"
#include "pkVisitor.h"
#include <algorithm>
#include <cstdint>
#include <vector>
#include "SegmentInterface.h"
#include "Utils.h"
#include "pkVisitor.h"
namespace milvus::segcore {
@ -160,6 +162,7 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offs
heap_.pop();
}
pk_set_.clear();
pairs_.clear();
pairs_.reserve(num_segments_);
for (int i = 0; i < num_segments_; i++) {
@ -183,7 +186,7 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offs
int64_t dup_cnt = 0;
auto start = offset;
while (offset - start < topk) {
while (offset - start < topk && !heap_.empty()) {
auto pilot = heap_.top();
heap_.pop();
@ -203,7 +206,9 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offs
dup_cnt++;
}
pilot->advance();
heap_.push(pilot);
if (pilot->primary_key_ != INVALID_PK) {
heap_.push(pilot);
}
}
return dup_cnt;
}

View File

@ -50,7 +50,7 @@ struct SearchResultPair {
distance_ = search_result_->distances_.at(offset_);
} else {
primary_key_ = INVALID_PK;
distance_ = std::numeric_limits<float>::max();
distance_ = std::numeric_limits<float>::min();
}
}
};
@ -58,6 +58,6 @@ struct SearchResultPair {
struct SearchResultPairComparator {
bool
operator()(const SearchResultPair* lhs, const SearchResultPair* rhs) const {
return *lhs > *rhs;
return *rhs > *lhs;
}
};

View File

@ -16,10 +16,10 @@
TEST(SearchResultPair, Greater) {
auto pair1 = SearchResultPair(0, 1.0, nullptr, 0, 0, 1);
auto pair2 = SearchResultPair(1, 2.0, nullptr, 1, 0, 10);
auto pair2 = SearchResultPair(1, 2.0, nullptr, 1, 0, 1);
ASSERT_EQ(pair1 > pair2, false);
pair1.advance();
pair2.advance();
ASSERT_EQ(pair1 > pair2, true);
ASSERT_EQ(pair1.primary_key_, INVALID_PK);
ASSERT_EQ(pair2.primary_key_, INVALID_PK);
}