mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
fix:add convert int to float function to array_contains related expr (#43468)
#43281 Signed-off-by: luzhang <luzhang@zilliz.com> Co-authored-by: luzhang <luzhang@zilliz.com>
This commit is contained in:
parent
4db877f76c
commit
d64dceea47
@ -55,7 +55,7 @@ class SingleElement : public BaseElement {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
void
|
void
|
||||||
SetValue(const proto::plan::GenericValue& value) {
|
SetValue(const proto::plan::GenericValue& value) {
|
||||||
value_ = GetValueFromProto<T>(value);
|
value_ = GetValueWithCastNumber<T>(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -122,7 +122,7 @@ class SortVectorElement : public MultiElement {
|
|||||||
explicit SortVectorElement(
|
explicit SortVectorElement(
|
||||||
const std::vector<proto::plan::GenericValue>& values) {
|
const std::vector<proto::plan::GenericValue>& values) {
|
||||||
for (auto& value : values) {
|
for (auto& value : values) {
|
||||||
values_.push_back(GetValueFromProto<T>(value));
|
values_.push_back(GetValueWithCastNumber<T>(value));
|
||||||
}
|
}
|
||||||
std::sort(values_.begin(), values_.end());
|
std::sort(values_.begin(), values_.end());
|
||||||
sorted_ = true;
|
sorted_ = true;
|
||||||
@ -178,7 +178,7 @@ class FlatVectorElement : public MultiElement {
|
|||||||
explicit FlatVectorElement(
|
explicit FlatVectorElement(
|
||||||
const std::vector<proto::plan::GenericValue>& values) {
|
const std::vector<proto::plan::GenericValue>& values) {
|
||||||
for (auto& value : values) {
|
for (auto& value : values) {
|
||||||
values_.push_back(GetValueFromProto<T>(value));
|
values_.push_back(GetValueWithCastNumber<T>(value));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -223,7 +223,7 @@ class SetElement : public MultiElement {
|
|||||||
public:
|
public:
|
||||||
explicit SetElement(const std::vector<proto::plan::GenericValue>& values) {
|
explicit SetElement(const std::vector<proto::plan::GenericValue>& values) {
|
||||||
for (auto& value : values) {
|
for (auto& value : values) {
|
||||||
values_.insert(GetValueFromProto<T>(value));
|
values_.insert(GetValueWithCastNumber<T>(value));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -84,18 +84,23 @@ PhyJsonContainsFilterExpr::EvalJsonContainsForDataSegment(EvalCtx& context) {
|
|||||||
case proto::plan::JSONContainsExpr_JSONOp_Contains:
|
case proto::plan::JSONContainsExpr_JSONOp_Contains:
|
||||||
case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: {
|
case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: {
|
||||||
if (IsArrayDataType(data_type)) {
|
if (IsArrayDataType(data_type)) {
|
||||||
auto val_type = expr_->vals_[0].val_case();
|
auto val_type = expr_->column_.element_type_;
|
||||||
switch (val_type) {
|
switch (val_type) {
|
||||||
case proto::plan::GenericValue::kBoolVal: {
|
case DataType::BOOL: {
|
||||||
return ExecArrayContains<bool>(context);
|
return ExecArrayContains<bool>(context);
|
||||||
}
|
}
|
||||||
case proto::plan::GenericValue::kInt64Val: {
|
case DataType::INT8:
|
||||||
|
case DataType::INT16:
|
||||||
|
case DataType::INT32:
|
||||||
|
case DataType::INT64: {
|
||||||
return ExecArrayContains<int64_t>(context);
|
return ExecArrayContains<int64_t>(context);
|
||||||
}
|
}
|
||||||
case proto::plan::GenericValue::kFloatVal: {
|
case DataType::FLOAT:
|
||||||
|
case DataType::DOUBLE: {
|
||||||
return ExecArrayContains<double>(context);
|
return ExecArrayContains<double>(context);
|
||||||
}
|
}
|
||||||
case proto::plan::GenericValue::kStringVal: {
|
case DataType::STRING:
|
||||||
|
case DataType::VARCHAR: {
|
||||||
return ExecArrayContains<std::string>(context);
|
return ExecArrayContains<std::string>(context);
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|||||||
@ -138,6 +138,18 @@ struct ColumnInfo {
|
|||||||
nullable_(nullable) {
|
nullable_(nullable) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ColumnInfo(FieldId field_id,
|
||||||
|
DataType data_type,
|
||||||
|
DataType element_type,
|
||||||
|
std::vector<std::string> nested_path = {},
|
||||||
|
bool nullable = false)
|
||||||
|
: field_id_(field_id),
|
||||||
|
data_type_(data_type),
|
||||||
|
element_type_(element_type),
|
||||||
|
nested_path_(std::move(nested_path)),
|
||||||
|
nullable_(nullable) {
|
||||||
|
}
|
||||||
|
|
||||||
bool
|
bool
|
||||||
operator==(const ColumnInfo& other) {
|
operator==(const ColumnInfo& other) {
|
||||||
if (field_id_ != other.field_id_) {
|
if (field_id_ != other.field_id_) {
|
||||||
|
|||||||
@ -1059,7 +1059,7 @@ TEST(Expr, TestArrayContains) {
|
|||||||
}
|
}
|
||||||
auto start = std::chrono::steady_clock::now();
|
auto start = std::chrono::steady_clock::now();
|
||||||
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
||||||
expr::ColumnInfo(bool_array_fid, DataType::ARRAY),
|
expr::ColumnInfo(bool_array_fid, DataType::ARRAY, DataType::BOOL),
|
||||||
proto::plan::JSONContainsExpr_JSONOp_Contains,
|
proto::plan::JSONContainsExpr_JSONOp_Contains,
|
||||||
true,
|
true,
|
||||||
values);
|
values);
|
||||||
@ -1130,7 +1130,8 @@ TEST(Expr, TestArrayContains) {
|
|||||||
}
|
}
|
||||||
auto start = std::chrono::steady_clock::now();
|
auto start = std::chrono::steady_clock::now();
|
||||||
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
||||||
expr::ColumnInfo(double_array_fid, DataType::ARRAY),
|
expr::ColumnInfo(
|
||||||
|
double_array_fid, DataType::ARRAY, DataType::DOUBLE),
|
||||||
proto::plan::JSONContainsExpr_JSONOp_Contains,
|
proto::plan::JSONContainsExpr_JSONOp_Contains,
|
||||||
true,
|
true,
|
||||||
values);
|
values);
|
||||||
@ -1191,7 +1192,7 @@ TEST(Expr, TestArrayContains) {
|
|||||||
}
|
}
|
||||||
auto start = std::chrono::steady_clock::now();
|
auto start = std::chrono::steady_clock::now();
|
||||||
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
||||||
expr::ColumnInfo(float_array_fid, DataType::ARRAY),
|
expr::ColumnInfo(float_array_fid, DataType::ARRAY, DataType::FLOAT),
|
||||||
proto::plan::JSONContainsExpr_JSONOp_Contains,
|
proto::plan::JSONContainsExpr_JSONOp_Contains,
|
||||||
true,
|
true,
|
||||||
values);
|
values);
|
||||||
@ -1262,7 +1263,7 @@ TEST(Expr, TestArrayContains) {
|
|||||||
}
|
}
|
||||||
auto start = std::chrono::steady_clock::now();
|
auto start = std::chrono::steady_clock::now();
|
||||||
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
||||||
expr::ColumnInfo(int_array_fid, DataType::ARRAY),
|
expr::ColumnInfo(int_array_fid, DataType::ARRAY, DataType::INT8),
|
||||||
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
|
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
|
||||||
true,
|
true,
|
||||||
values);
|
values);
|
||||||
@ -1324,7 +1325,7 @@ TEST(Expr, TestArrayContains) {
|
|||||||
}
|
}
|
||||||
auto start = std::chrono::steady_clock::now();
|
auto start = std::chrono::steady_clock::now();
|
||||||
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
||||||
expr::ColumnInfo(long_array_fid, DataType::ARRAY),
|
expr::ColumnInfo(long_array_fid, DataType::ARRAY, DataType::INT64),
|
||||||
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
|
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
|
||||||
true,
|
true,
|
||||||
values);
|
values);
|
||||||
@ -1393,7 +1394,8 @@ TEST(Expr, TestArrayContains) {
|
|||||||
}
|
}
|
||||||
auto start = std::chrono::steady_clock::now();
|
auto start = std::chrono::steady_clock::now();
|
||||||
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
auto expr = std::make_shared<milvus::expr::JsonContainsExpr>(
|
||||||
expr::ColumnInfo(string_array_fid, DataType::ARRAY),
|
expr::ColumnInfo(
|
||||||
|
string_array_fid, DataType::ARRAY, DataType::VARCHAR),
|
||||||
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
|
proto::plan::JSONContainsExpr_JSONOp_ContainsAll,
|
||||||
true,
|
true,
|
||||||
values);
|
values);
|
||||||
|
|||||||
@ -463,6 +463,34 @@ class TestCollectionSearchJSON(TestcaseBase):
|
|||||||
exp_ids = cf.assert_json_contains(expression, string_field_value)
|
exp_ids = cf.assert_json_contains(expression, string_field_value)
|
||||||
assert set(res[0].ids) == set(exp_ids)
|
assert set(res[0].ids) == set(exp_ids)
|
||||||
|
|
||||||
|
@pytest.mark.tags(CaseLabel.L1)
|
||||||
|
@pytest.mark.parametrize("expr_prefix", ["array_contains_any", "ARRAY_CONTAINS_ANY",
|
||||||
|
"not array_contains_any", "not ARRAY_CONTAINS_ANY"])
|
||||||
|
def test_search_expr_array_contains_any_with_float_field(self, expr_prefix):
|
||||||
|
"""
|
||||||
|
target: test query with expression using array_contains with float field
|
||||||
|
method: query with expression using array_contains with float field
|
||||||
|
expected: succeed
|
||||||
|
"""
|
||||||
|
# 1. create a collection
|
||||||
|
schema = cf.gen_array_collection_schema()
|
||||||
|
collection_w = self.init_collection_wrap(schema=schema)
|
||||||
|
|
||||||
|
# 2. insert data
|
||||||
|
float_field_value = [[random.random() for j in range(i, i + 3)] for i in range(ct.default_nb)]
|
||||||
|
data = cf.gen_array_dataframe_data()
|
||||||
|
data[ct.default_float_array_field_name] = float_field_value
|
||||||
|
collection_w.insert(data)
|
||||||
|
collection_w.create_index(ct.default_float_vec_field_name, {})
|
||||||
|
|
||||||
|
# 3. search with array_contains_any with float and int target
|
||||||
|
collection_w.load()
|
||||||
|
expression = f"{expr_prefix}({ct.default_float_array_field_name}, [0.5, 0.6, 1, 2])"
|
||||||
|
res = collection_w.search(vectors[:default_nq], default_search_field, {},
|
||||||
|
limit=ct.default_nb, expr=expression)[0]
|
||||||
|
exp_ids = cf.assert_json_contains(expression, float_field_value)
|
||||||
|
assert set(res[0].ids) == set(exp_ids)
|
||||||
|
|
||||||
@pytest.mark.tags(CaseLabel.L2)
|
@pytest.mark.tags(CaseLabel.L2)
|
||||||
@pytest.mark.parametrize("expr_prefix", ["array_contains_all", "ARRAY_CONTAINS_ALL",
|
@pytest.mark.parametrize("expr_prefix", ["array_contains_all", "ARRAY_CONTAINS_ALL",
|
||||||
"array_contains_any", "ARRAY_CONTAINS_ANY"])
|
"array_contains_any", "ARRAY_CONTAINS_ANY"])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user