From 476cf61d9868d6d3a5f4b664656bb247fe44ed95 Mon Sep 17 00:00:00 2001 From: Spade A <71589810+SpadeA-Tang@users.noreply.github.com> Date: Wed, 26 Feb 2025 16:15:58 +0800 Subject: [PATCH] fix: random sample consider empty input (#40201) issue: #40198 Fix random sample does not consider empty input, that is no data is hit by filter expression. --------- Signed-off-by: SpadeA --- .../src/exec/operator/RandomSampleNode.cpp | 36 ++++++++++------- internal/core/unittest/test_random_sample.cpp | 40 +++++++++++++++++++ 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/internal/core/src/exec/operator/RandomSampleNode.cpp b/internal/core/src/exec/operator/RandomSampleNode.cpp index 4a38a3ad8c..b3d46f98f9 100644 --- a/internal/core/src/exec/operator/RandomSampleNode.cpp +++ b/internal/core/src/exec/operator/RandomSampleNode.cpp @@ -106,22 +106,27 @@ PhyRandomSampleNode::GetOutput() { TargetBitmapView input_data(input_col->GetRawData(), input_col->size()); // note: false means the elemnt is hit size_t input_false_count = input_data.size() - input_data.count(); - FixedVector pos{}; - pos.reserve(input_false_count); - auto value = input_data.find_first(false); - while (value.has_value()) { - auto offset = value.value(); - pos.push_back(offset); - value = input_data.find_next(offset, false); - } - assert(pos.size() == input_false_count); - input_data.set(); - auto sampled = Sample(input_false_count, factor_); - assert(sampled.back() < input_false_count); - for (auto i = 0; i < sampled.size(); ++i) { - input_data[pos[sampled[i]]] = false; + if (input_false_count > 0) { + FixedVector pos{}; + pos.reserve(input_false_count); + auto value = input_data.find_first(false); + while (value.has_value()) { + auto offset = value.value(); + pos.push_back(offset); + value = input_data.find_next(offset, false); + } + assert(pos.size() == input_false_count); + + input_data.set(); + auto sampled = Sample(input_false_count, factor_); + assert(sampled.back() < input_false_count); + for (auto i = 0; i < sampled.size(); ++i) { + input_data[pos[sampled[i]]] = false; + } } + + is_finished_ = true; return std::make_shared(std::vector{input_col}); } else { auto sample_output = std::make_shared( @@ -143,10 +148,11 @@ PhyRandomSampleNode::GetOutput() { data[n] = true; } - is_finished_ = true; if (need_flip) { data.flip(); } + + is_finished_ = true; return std::make_shared( std::vector{sample_output}); } diff --git a/internal/core/unittest/test_random_sample.cpp b/internal/core/unittest/test_random_sample.cpp index eb17129b62..7a8c712305 100644 --- a/internal/core/unittest/test_random_sample.cpp +++ b/internal/core/unittest/test_random_sample.cpp @@ -131,4 +131,44 @@ TEST_P(RandomSampleTest, SampleWithUnaryFiler) { int expected_size = static_cast(N * sample_factor) / 3; // We can accept size one difference due to the float point calculation in sampling. assert(expected_size - 1 <= data_size && data_size <= expected_size + 1); +} + +TEST(RandomSampleTest, SampleWithEmptyInput) { + double sample_factor = 0.1; + + auto schema = std::make_shared(); + auto fid_64 = schema->AddDebugField("i64", DataType::INT64); + schema->set_primary_field_id(fid_64); + fid_64 = schema->AddDebugField("integer", DataType::INT64); + + const int64_t N = 3000; + auto dataset = DataGen(schema, N); + auto segment = CreateSealedSegment(schema); + + SealedLoadFieldData(dataset, *segment); + + milvus::proto::plan::GenericValue val; + val.set_int64_val(0); + // Less than 0 will not match any data + auto expr = std::make_shared( + milvus::expr::ColumnInfo( + fid_64, DataType::INT64, std::vector()), + OpType::LessThan, + val, + std::vector{}); + auto plan = std::make_unique(*schema); + plan->plan_node_ = std::make_unique(); + plan->plan_node_->plannodes_ = + milvus::test::CreateRetrievePlanForRandomSample(sample_factor, expr); + std::vector target_offsets{fid_64}; + plan->field_ids_ = target_offsets; + + auto retrieve_results = RetrieveWithDefaultOutputSizeAndLargeTimestamp( + segment.get(), plan.get()); + Assert(retrieve_results->fields_data_size() == target_offsets.size()); + auto field = retrieve_results->fields_data(0); + auto field_data = field.scalars().long_data(); + int data_size = field.scalars().long_data().data_size(); + + assert(data_size == 0); } \ No newline at end of file