fix: add int-to-float conversion function to array_contains-related expressions (#43468)

#43281

Signed-off-by: luzhang <luzhang@zilliz.com>
Co-authored-by: luzhang <luzhang@zilliz.com>
This commit is contained in:
zhagnlu 2025-07-23 15:20:53 +08:00 committed by GitHub
parent 4db877f76c
commit d64dceea47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 62 additions and 15 deletions

View File

@ -55,7 +55,7 @@ class SingleElement : public BaseElement {
template <typename T>
void
SetValue(const proto::plan::GenericValue& value) {
value_ = GetValueFromProto<T>(value);
value_ = GetValueWithCastNumber<T>(value);
}
template <typename T>
@ -122,7 +122,7 @@ class SortVectorElement : public MultiElement {
explicit SortVectorElement(
const std::vector<proto::plan::GenericValue>& values) {
for (auto& value : values) {
values_.push_back(GetValueFromProto<T>(value));
values_.push_back(GetValueWithCastNumber<T>(value));
}
std::sort(values_.begin(), values_.end());
sorted_ = true;
@ -178,7 +178,7 @@ class FlatVectorElement : public MultiElement {
explicit FlatVectorElement(
const std::vector<proto::plan::GenericValue>& values) {
for (auto& value : values) {
values_.push_back(GetValueFromProto<T>(value));
values_.push_back(GetValueWithCastNumber<T>(value));
}
}
@ -223,7 +223,7 @@ class SetElement : public MultiElement {
public:
explicit SetElement(const std::vector<proto::plan::GenericValue>& values) {
for (auto& value : values) {
values_.insert(GetValueFromProto<T>(value));
values_.insert(GetValueWithCastNumber<T>(value));
}
}

View File

@ -84,18 +84,23 @@ PhyJsonContainsFilterExpr::EvalJsonContainsForDataSegment(EvalCtx& context) {
case proto::plan::JSONContainsExpr_JSONOp_Contains:
case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: {
if (IsArrayDataType(data_type)) {
auto val_type = expr_->vals_[0].val_case();
auto val_type = expr_->column_.element_type_;
switch (val_type) {
case proto::plan::GenericValue::kBoolVal: {
case DataType::BOOL: {
return ExecArrayContains<bool>(context);
}
case proto::plan::GenericValue::kInt64Val: {
case DataType::INT8:
case DataType::INT16:
case DataType::INT32:
case DataType::INT64: {
return ExecArrayContains<int64_t>(context);
}
case proto::plan::GenericValue::kFloatVal: {
case DataType::FLOAT:
case DataType::DOUBLE: {
return ExecArrayContains<double>(context);
}
case proto::plan::GenericValue::kStringVal: {
case DataType::STRING:
case DataType::VARCHAR: {
return ExecArrayContains<std::string>(context);
}
default:

View File

@ -138,6 +138,18 @@ struct ColumnInfo {
nullable_(nullable) {
}
// Constructs a ColumnInfo that also records an element type.
//
// field_id, data_type, element_type and nullable are copied into the
// matching members; nested_path is taken by value and moved into
// nested_path_. nested_path defaults to an empty vector and nullable
// defaults to false.
//
// NOTE(review): element_type presumably describes the type of the values
// stored inside an array column (data_type == DataType::ARRAY) — confirm
// against the callers that read element_type_.
ColumnInfo(FieldId field_id,
DataType data_type,
DataType element_type,
std::vector<std::string> nested_path = {},
bool nullable = false)
: field_id_(field_id),
data_type_(data_type),
element_type_(element_type),
nested_path_(std::move(nested_path)),
nullable_(nullable) {
}
bool
operator==(const ColumnInfo& other) {
if (field_id_ != other.field_id_) {

View File

@ -1059,7 +1059,7 @@ TEST(Expr, TestArrayContains) {
}
auto start = std::chrono::steady_clock::now();
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
expr::ColumnInfo(bool_array_fid, DataType::ARRAY),
expr::ColumnInfo(bool_array_fid, DataType::ARRAY, DataType::BOOL),
proto::plan::JSONContainsExpr_JSONOp_Contains,
true,
values);
@ -1130,7 +1130,8 @@ TEST(Expr, TestArrayContains) {
}
auto start = std::chrono::steady_clock::now();
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
expr::ColumnInfo(double_array_fid, DataType::ARRAY),
expr::ColumnInfo(
double_array_fid, DataType::ARRAY, DataType::DOUBLE),
proto::plan::JSONContainsExpr_JSONOp_Contains,
true,
values);
@ -1191,7 +1192,7 @@ TEST(Expr, TestArrayContains) {
}
auto start = std::chrono::steady_clock::now();
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
expr::ColumnInfo(float_array_fid, DataType::ARRAY),
expr::ColumnInfo(float_array_fid, DataType::ARRAY, DataType::FLOAT),
proto::plan::JSONContainsExpr_JSONOp_Contains,
true,
values);
@ -1262,7 +1263,7 @@ TEST(Expr, TestArrayContains) {
}
auto start = std::chrono::steady_clock::now();
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
expr::ColumnInfo(int_array_fid, DataType::ARRAY),
expr::ColumnInfo(int_array_fid, DataType::ARRAY, DataType::INT8),
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
true,
values);
@ -1324,7 +1325,7 @@ TEST(Expr, TestArrayContains) {
}
auto start = std::chrono::steady_clock::now();
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
expr::ColumnInfo(long_array_fid, DataType::ARRAY),
expr::ColumnInfo(long_array_fid, DataType::ARRAY, DataType::INT64),
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
true,
values);
@ -1393,7 +1394,8 @@ TEST(Expr, TestArrayContains) {
}
auto start = std::chrono::steady_clock::now();
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
expr::ColumnInfo(string_array_fid, DataType::ARRAY),
expr::ColumnInfo(
string_array_fid, DataType::ARRAY, DataType::VARCHAR),
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
true,
values);

View File

@ -463,6 +463,34 @@ class TestCollectionSearchJSON(TestcaseBase):
exp_ids = cf.assert_json_contains(expression, string_field_value)
assert set(res[0].ids) == set(exp_ids)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expr_prefix", ["array_contains_any", "ARRAY_CONTAINS_ANY",
"not array_contains_any", "not ARRAY_CONTAINS_ANY"])
def test_search_expr_array_contains_any_with_float_field(self, expr_prefix):
"""
target: test search with an array_contains_any expression on a float array field
method: search with array_contains_any (plain or negated with "not") whose target
list mixes float and int literals
expected: succeed, and the returned ids equal the ids computed by cf.assert_json_contains
"""
# 1. create a collection
schema = cf.gen_array_collection_schema()
collection_w = self.init_collection_wrap(schema=schema)
# 2. insert data
# Each row gets a list of three random floats; the range(i, i + 3) bounds only
# fix the list length, j itself is unused.
# NOTE(review): random.random() returns values in [0.0, 1.0), so an exact match
# with the targets 0.5 / 0.6 is very unlikely and 1 / 2 can never match —
# presumably the non-negated cases expect an (almost always) empty id set; confirm
# this is the intended coverage.
float_field_value = [[random.random() for j in range(i, i + 3)] for i in range(ct.default_nb)]
data = cf.gen_array_dataframe_data()
# Replace the generated float-array column with the values built above.
data[ct.default_float_array_field_name] = float_field_value
collection_w.insert(data)
collection_w.create_index(ct.default_float_vec_field_name, {})
# 3. search with array_contains_any with float and int target
collection_w.load()
expression = f"{expr_prefix}({ct.default_float_array_field_name}, [0.5, 0.6, 1, 2])"
res = collection_w.search(vectors[:default_nq], default_search_field, {},
limit=ct.default_nb, expr=expression)[0]
# Compare the ids of the first result entry with the ids the helper derives from
# the same expression and the inserted values.
exp_ids = cf.assert_json_contains(expression, float_field_value)
assert set(res[0].ids) == set(exp_ids)
@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("expr_prefix", ["array_contains_all", "ARRAY_CONTAINS_ALL",
"array_contains_any", "ARRAY_CONTAINS_ANY"])