enhance: support binary range expression for json path index (#41317)

pr: #41025 
issue: #35528

---------

Signed-off-by: sunby <sunbingyi1992@gmail.com>
This commit is contained in:
Bingyi Sun 2025-04-24 20:04:39 +08:00 committed by GitHub
parent 8b3353cdab
commit 4ac57f1217
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 247 additions and 18 deletions

View File

@ -68,24 +68,60 @@ PhyBinaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) {
}
case DataType::JSON: {
auto value_type = expr_->lower_val_.val_case();
switch (value_type) {
case proto::plan::GenericValue::ValCase::kInt64Val: {
result = ExecRangeVisitorImplForJson<int64_t>(context);
break;
if (is_index_mode_ && !has_offset_input_) {
switch (value_type) {
case proto::plan::GenericValue::ValCase::kInt64Val: {
proto::plan::GenericValue double_lower_val;
double_lower_val.set_float_val(
static_cast<double>(expr_->lower_val_.int64_val()));
proto::plan::GenericValue double_upper_val;
double_upper_val.set_float_val(
static_cast<double>(expr_->upper_val_.int64_val()));
lower_arg_.SetValue<double>(double_lower_val);
upper_arg_.SetValue<double>(double_upper_val);
arg_inited_ = true;
result = ExecRangeVisitorImplForIndex<double>();
break;
}
case proto::plan::GenericValue::ValCase::kFloatVal: {
result = ExecRangeVisitorImplForIndex<double>();
break;
}
case proto::plan::GenericValue::ValCase::kStringVal: {
result =
ExecRangeVisitorImplForJson<std::string>(context);
break;
}
default: {
PanicInfo(DataTypeInvalid,
fmt::format(
"unsupported value type {} in expression",
value_type));
}
}
case proto::plan::GenericValue::ValCase::kFloatVal: {
result = ExecRangeVisitorImplForJson<double>(context);
break;
}
case proto::plan::GenericValue::ValCase::kStringVal: {
result = ExecRangeVisitorImplForJson<std::string>(context);
break;
}
default: {
PanicInfo(
DataTypeInvalid,
fmt::format("unsupported value type {} in expression",
value_type));
} else {
switch (value_type) {
case proto::plan::GenericValue::ValCase::kInt64Val: {
result = ExecRangeVisitorImplForJson<int64_t>(context);
break;
}
case proto::plan::GenericValue::ValCase::kFloatVal: {
result = ExecRangeVisitorImplForJson<double>(context);
break;
}
case proto::plan::GenericValue::ValCase::kStringVal: {
result =
ExecRangeVisitorImplForJson<std::string>(context);
break;
}
default: {
PanicInfo(DataTypeInvalid,
fmt::format(
"unsupported value type {} in expression",
value_type));
}
}
}
break;

View File

@ -252,7 +252,7 @@ class PhyBinaryRangeFilterExpr : public SegmentExpr {
segment,
expr->column_.field_id_,
expr->column_.nested_path_,
DataType::NONE,
FromValCase(expr->lower_val_.val_case()),
active_count,
batch_size,
consistency_level),

View File

@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <algorithm>
#include <any>
#include <boost/format.hpp>
#include <fstream>
#include <gtest/gtest.h>
@ -27,8 +28,10 @@
#include "common/FieldDataInterface.h"
#include "common/Json.h"
#include "common/JsonCastType.h"
#include "common/LoadInfo.h"
#include "common/Types.h"
#include "gtest/gtest.h"
#include "index/Meta.h"
#include "index/JsonInvertedIndex.h"
#include "knowhere/comp/index_param.h"
@ -16880,3 +16883,193 @@ TEST_P(JsonIndexExistsTest, TestExistsExpr) {
EXPECT_TRUE(result == expect_res);
}
}
// Value-parameterized fixture for binary range expressions evaluated against a
// JSON path index; the parameter is the JsonCastType the index is created with.
class JsonIndexBinaryExprTest : public testing::TestWithParam<JsonCastType> {};
// Run every TEST_P of this fixture once with a DOUBLE-cast index and once with
// a VARCHAR-cast index.
INSTANTIATE_TEST_SUITE_P(JsonIndexBinaryExprTestParams,
JsonIndexBinaryExprTest,
testing::Values(JsonCastType::DOUBLE,
JsonCastType::VARCHAR));
// Builds a sealed segment whose JSON field has an inverted index on path "/a"
// (created with the cast type supplied by GetParam()), then evaluates a series
// of binary range filters on that path and compares the resulting bitset with
// a hand-written expectation for each case.
TEST_P(JsonIndexBinaryExprTest, TestBinaryRangeExpr) {
    // One JSON document per row. Rows 0-3 hold integers, 4-7 floats,
    // 8-11 numeric strings, 12 null and 13-14 booleans.
    const std::vector<std::string> json_strs = {
        R"({"a": 1})",
        R"({"a": 2})",
        R"({"a": 3})",
        R"({"a": 4})",
        R"({"a": 1.0})",
        R"({"a": 2.0})",
        R"({"a": 3.0})",
        R"({"a": 4.0})",
        R"({"a": "1"})",
        R"({"a": "2"})",
        R"({"a": "3"})",
        R"({"a": "4"})",
        R"({"a": null})",
        R"({"a": true})",
        R"({"a": false})",
    };
    const auto row_count = json_strs.size();

    // A single range query plus the rows it is expected to select.
    // `expected_mask` is written with one binary digit per row: the leftmost
    // digit stands for row 0 and the rightmost digit for the last row.
    struct RangeCase {
        std::any lower;
        std::any upper;
        bool lower_inclusive;
        bool upper_inclusive;
        uint32_t expected_mask;
    };
    const std::vector<RangeCase> range_cases = {
        // Exact match for integer 1 (matches both int 1 and float 1.0)
        {std::make_any<int64_t>(1),
         std::make_any<int64_t>(1),
         true,
         true,
         0b1000'1000'0000'000},
        // Range [1, 3] inclusive (matches int 1,2,3 and float 1.0,2.0,3.0)
        {std::make_any<int64_t>(1),
         std::make_any<int64_t>(3),
         true,
         true,
         0b1110'1110'0000'000},
        // Range (1, 3) exclusive (matches only int 2 and float 2.0)
        {std::make_any<int64_t>(1),
         std::make_any<int64_t>(3),
         false,
         false,
         0b0100'0100'0000'000},
        // Range [1, 3) left inclusive, right exclusive
        // (matches int 1,2 and float 1.0,2.0)
        {std::make_any<int64_t>(1),
         std::make_any<int64_t>(3),
         true,
         false,
         0b1100'1100'0000'000},
        // Range (1, 3] left exclusive, right inclusive
        // (matches int 2,3 and float 2.0,3.0)
        {std::make_any<int64_t>(1),
         std::make_any<int64_t>(3),
         false,
         true,
         0b0110'0110'0000'000},
        // Float range test [1.0, 3.0] (matches int 1,2,3 and float 1.0,2.0,3.0)
        {std::make_any<double>(1.0),
         std::make_any<double>(3.0),
         true,
         true,
         0b1110'1110'0000'000},
        // String range test ["1", "3"] (matches string "1","2","3")
        {std::make_any<std::string>("1"),
         std::make_any<std::string>("3"),
         true,
         true,
         0b0000'0000'1110'000},
        // Range that should match nothing
        {std::make_any<int64_t>(10),
         std::make_any<int64_t>(20),
         true,
         true,
         0b0000'0000'0000'000},
        // Range [2, 4] inclusive (matches int 2,3,4 and float 2.0,3.0,4.0)
        {std::make_any<int64_t>(2),
         std::make_any<int64_t>(4),
         true,
         true,
         0b0111'0111'0000'000},
        // Mixed type range test - int to float [1, 3.0]
        // {std::make_any<int64_t>(1),
        //  std::make_any<double>(3.0),
        //  true,
        //  true,
        //  0b1110'1110'0000'000},
    };

    // Schema: a float vector, an int64 primary key and the JSON field under test.
    auto schema = std::make_shared<Schema>();
    auto vec_fid = schema->AddDebugField(
        "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
    auto i64_fid = schema->AddDebugField("age64", DataType::INT64);
    auto json_fid = schema->AddDebugField("json", DataType::JSON);
    schema->set_primary_field_id(i64_fid);

    auto seg = CreateSealedSegment(schema);
    segcore::LoadIndexInfo load_index_info;

    // Describe the indexed field to the index factory.
    auto file_manager_ctx = storage::FileManagerContext();
    file_manager_ctx.fieldDataMeta.field_schema.set_data_type(
        milvus::proto::schema::JSON);
    file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get());

    auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex(
        index::INVERTED_INDEX_TYPE, GetParam(), "/a", file_manager_ctx);
    // NOTE(review): the pointer is cast to JsonInvertedIndex<double> for every
    // parameter, including JsonCastType::VARCHAR - confirm that this is valid
    // for the object the factory returns in that case.
    using json_index_type = index::JsonInvertedIndex<double>;
    std::unique_ptr<json_index_type> json_index(
        static_cast<json_index_type*>(inv_index.release()));

    // Materialise the documents as field data and build the index from them.
    auto json_field =
        std::make_shared<FieldData<milvus::Json>>(DataType::JSON, false);
    std::vector<milvus::Json> jsons;
    jsons.reserve(row_count);
    for (const auto& doc : json_strs) {
        jsons.push_back(milvus::Json(simdjson::padded_string(doc)));
    }
    json_field->add_json_data(jsons);
    json_index->BuildWithFieldData({json_field});
    json_index->finish();
    json_index->create_reader();

    // Attach both the index and the raw field data to the segment.
    load_index_info.field_id = json_fid.get();
    load_index_info.field_type = DataType::JSON;
    load_index_info.index = std::move(json_index);
    load_index_info.index_params = {{JSON_PATH, "/a"}};
    seg->LoadIndex(load_index_info);

    auto json_field_data_info =
        FieldDataInfo(json_fid.get(), row_count, {json_field});
    seg->LoadFieldData(json_fid, json_field_data_info);

    // Wraps a type-erased bound into the protobuf value used by the expression.
    // A bound of any other type yields a GenericValue with no value set.
    auto to_generic_value = [](const std::any& bound) {
        proto::plan::GenericValue value;
        if (bound.type() == typeid(int64_t)) {
            value.set_int64_val(std::any_cast<int64_t>(bound));
        } else if (bound.type() == typeid(double)) {
            value.set_float_val(std::any_cast<double>(bound));
        } else if (bound.type() == typeid(std::string)) {
            value.set_string_val(std::any_cast<std::string>(bound));
        }
        return value;
    };

    for (const auto& range_case : range_cases) {
        auto lower_val = to_generic_value(range_case.lower);
        auto upper_val = to_generic_value(range_case.upper);

        // Decode the mask from its least significant bit, which belongs to the
        // last row, walking towards row 0 until no set bits remain.
        BitsetType expect_result;
        expect_result.resize(row_count);
        uint32_t pending_bits = range_case.expected_mask;
        int row = row_count - 1;
        while (pending_bits > 0) {
            expect_result.set(row, (pending_bits & 0x1) != 0);
            pending_bits >>= 1;
            --row;
        }

        auto binary_expr = std::make_shared<expr::BinaryRangeFilterExpr>(
            expr::ColumnInfo(json_fid, DataType::JSON, {"a"}),
            lower_val,
            upper_val,
            range_case.lower_inclusive,
            range_case.upper_inclusive);
        auto plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
                                                           binary_expr);
        auto res = ExecuteQueryExpr(plan, seg.get(), row_count, MAX_TIMESTAMP);
        EXPECT_TRUE(res == expect_result);
    }
}