mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
fix: random sample consider empty input (#40201)
issue: #40198 Fix random sample does not consider empty input, that is no data is hit by filter expression. --------- Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
This commit is contained in:
parent
8f077089ba
commit
476cf61d98
@ -106,22 +106,27 @@ PhyRandomSampleNode::GetOutput() {
|
||||
TargetBitmapView input_data(input_col->GetRawData(), input_col->size());
|
||||
// note: false means the elemnt is hit
|
||||
size_t input_false_count = input_data.size() - input_data.count();
|
||||
FixedVector<uint32_t> pos{};
|
||||
pos.reserve(input_false_count);
|
||||
auto value = input_data.find_first(false);
|
||||
while (value.has_value()) {
|
||||
auto offset = value.value();
|
||||
pos.push_back(offset);
|
||||
value = input_data.find_next(offset, false);
|
||||
}
|
||||
assert(pos.size() == input_false_count);
|
||||
|
||||
input_data.set();
|
||||
auto sampled = Sample(input_false_count, factor_);
|
||||
assert(sampled.back() < input_false_count);
|
||||
for (auto i = 0; i < sampled.size(); ++i) {
|
||||
input_data[pos[sampled[i]]] = false;
|
||||
if (input_false_count > 0) {
|
||||
FixedVector<uint32_t> pos{};
|
||||
pos.reserve(input_false_count);
|
||||
auto value = input_data.find_first(false);
|
||||
while (value.has_value()) {
|
||||
auto offset = value.value();
|
||||
pos.push_back(offset);
|
||||
value = input_data.find_next(offset, false);
|
||||
}
|
||||
assert(pos.size() == input_false_count);
|
||||
|
||||
input_data.set();
|
||||
auto sampled = Sample(input_false_count, factor_);
|
||||
assert(sampled.back() < input_false_count);
|
||||
for (auto i = 0; i < sampled.size(); ++i) {
|
||||
input_data[pos[sampled[i]]] = false;
|
||||
}
|
||||
}
|
||||
|
||||
is_finished_ = true;
|
||||
return std::make_shared<RowVector>(std::vector<VectorPtr>{input_col});
|
||||
} else {
|
||||
auto sample_output = std::make_shared<ColumnVector>(
|
||||
@ -143,10 +148,11 @@ PhyRandomSampleNode::GetOutput() {
|
||||
data[n] = true;
|
||||
}
|
||||
|
||||
is_finished_ = true;
|
||||
if (need_flip) {
|
||||
data.flip();
|
||||
}
|
||||
|
||||
is_finished_ = true;
|
||||
return std::make_shared<RowVector>(
|
||||
std::vector<VectorPtr>{sample_output});
|
||||
}
|
||||
|
||||
@ -131,4 +131,44 @@ TEST_P(RandomSampleTest, SampleWithUnaryFiler) {
|
||||
int expected_size = static_cast<int>(N * sample_factor) / 3;
|
||||
// We can accept size one difference due to the float point calculation in sampling.
|
||||
assert(expected_size - 1 <= data_size && data_size <= expected_size + 1);
|
||||
}
|
||||
|
||||
TEST(RandomSampleTest, SampleWithEmptyInput) {
|
||||
double sample_factor = 0.1;
|
||||
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto fid_64 = schema->AddDebugField("i64", DataType::INT64);
|
||||
schema->set_primary_field_id(fid_64);
|
||||
fid_64 = schema->AddDebugField("integer", DataType::INT64);
|
||||
|
||||
const int64_t N = 3000;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = CreateSealedSegment(schema);
|
||||
|
||||
SealedLoadFieldData(dataset, *segment);
|
||||
|
||||
milvus::proto::plan::GenericValue val;
|
||||
val.set_int64_val(0);
|
||||
// Less than 0 will not match any data
|
||||
auto expr = std::make_shared<milvus::expr::UnaryRangeFilterExpr>(
|
||||
milvus::expr::ColumnInfo(
|
||||
fid_64, DataType::INT64, std::vector<std::string>()),
|
||||
OpType::LessThan,
|
||||
val,
|
||||
std::vector<proto::plan::GenericValue>{});
|
||||
auto plan = std::make_unique<query::RetrievePlan>(*schema);
|
||||
plan->plan_node_ = std::make_unique<query::RetrievePlanNode>();
|
||||
plan->plan_node_->plannodes_ =
|
||||
milvus::test::CreateRetrievePlanForRandomSample(sample_factor, expr);
|
||||
std::vector<FieldId> target_offsets{fid_64};
|
||||
plan->field_ids_ = target_offsets;
|
||||
|
||||
auto retrieve_results = RetrieveWithDefaultOutputSizeAndLargeTimestamp(
|
||||
segment.get(), plan.get());
|
||||
Assert(retrieve_results->fields_data_size() == target_offsets.size());
|
||||
auto field = retrieve_results->fields_data(0);
|
||||
auto field_data = field.scalars().long_data();
|
||||
int data_size = field.scalars().long_data().data_size();
|
||||
|
||||
assert(data_size == 0);
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user