From 532f10f3436b4d19e3a8365cca8d35ad32c30188 Mon Sep 17 00:00:00 2001 From: Bingyi Sun Date: Thu, 5 Jun 2025 20:26:33 +0800 Subject: [PATCH] enhance: Support cast function for json index (#42504) issue: #41948 pr: #41949 --------- Signed-off-by: sunby --- internal/core/src/common/Consts.h | 4 +- internal/core/src/common/Json.h | 12 ++ internal/core/src/common/JsonCastFunction.cpp | 124 ++++++++++++++++++ internal/core/src/common/JsonCastFunction.h | 79 +++++++++++ internal/core/src/index/IndexFactory.cpp | 21 ++- internal/core/src/index/IndexFactory.h | 12 +- internal/core/src/index/IndexInfo.h | 1 + internal/core/src/index/JsonInvertedIndex.cpp | 20 ++- internal/core/src/index/JsonInvertedIndex.h | 10 +- .../src/indexbuilder/ScalarIndexCreator.cpp | 4 + internal/core/unittest/test_json_index.cpp | 81 ++++++++++++ internal/proxy/task_index.go | 3 + .../util/indexparamcheck/inverted_checker.go | 13 ++ pkg/common/common.go | 5 +- 14 files changed, 368 insertions(+), 21 deletions(-) create mode 100644 internal/core/src/common/JsonCastFunction.cpp create mode 100644 internal/core/src/common/JsonCastFunction.h diff --git a/internal/core/src/common/Consts.h b/internal/core/src/common/Consts.h index c3868169d1..13244ad111 100644 --- a/internal/core/src/common/Consts.h +++ b/internal/core/src/common/Consts.h @@ -82,7 +82,7 @@ const size_t MARISA_NULL_KEY_ID = -1; const bool DEFAULT_JSON_INDEX_ENABLED = true; const std::string JSON_CAST_TYPE = "json_cast_type"; const std::string JSON_PATH = "json_path"; - +const std::string JSON_CAST_FUNCTION = "json_cast_function"; const bool DEFAULT_OPTIMIZE_EXPR_ENABLED = true; const bool DEFAULT_GROWING_JSON_KEY_STATS_ENABLED = false; const int64_t DEFAULT_JSON_KEY_STATS_COMMIT_INTERVAL = 200; @@ -100,3 +100,5 @@ const std::string DATA_TYPE_KEY = "data_type"; // storage version const int64_t STORAGE_V1 = 1; const int64_t STORAGE_V2 = 2; + +const std::string UNKNOW_CAST_FUNCTION_NAME = "unknown"; diff --git a/internal/core/src/common/Json.h b/internal/core/src/common/Json.h index cdcc544194..e364bf2fc3 100644 --- a/internal/core/src/common/Json.h +++ b/internal/core/src/common/Json.h @@ -221,6 +221,18 @@ class Json { return pointer; } + auto + type(const std::string& pointer) const { + return pointer.empty() ? doc().type() + : doc().at_pointer(pointer).type(); + } + + auto + get_number_type(const std::string& pointer) const { + return pointer.empty() ? doc().get_number_type() + : doc().at_pointer(pointer).get_number_type(); + } + template value_result at(std::string_view pointer) const { diff --git a/internal/core/src/common/JsonCastFunction.cpp b/internal/core/src/common/JsonCastFunction.cpp new file mode 100644 index 0000000000..5054f76b5b --- /dev/null +++ b/internal/core/src/common/JsonCastFunction.cpp @@ -0,0 +1,124 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include "common/JsonCastFunction.h" +#include +#include "common/EasyAssert.h" + +namespace milvus { + +const std::unordered_map + JsonCastFunction::predefined_cast_functions_ = { + {"STRING_TO_DOUBLE", JsonCastFunction(Type::kString2Double)}, +}; + +JsonCastFunction +JsonCastFunction::FromString(const std::string& str) { + auto it = predefined_cast_functions_.find(str); + if (it != predefined_cast_functions_.end()) { + return it->second; + } + return JsonCastFunction(Type::kUnknown); +} + +template <> +std::optional +JsonCastFunction::cast(const std::string& t) const { + try { + return std::stod(t); + } catch (const std::exception&) { + return std::nullopt; + } +} + +template <> +std::optional +JsonCastFunction::cast(const int64_t& t) const { + return static_cast(t); +} + +template <> +std::optional +JsonCastFunction::cast(const double& t) const { + return t; +} + +template <> +std::optional +JsonCastFunction::cast(const bool& t) const { + return std::nullopt; +} + +template +std::optional +JsonCastFunction::CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer) { + AssertInfo(cast_function.match(), "Type mismatch"); + + auto json_type = json.type(pointer); + std::optional res; + + switch (json_type.value()) { + case simdjson::ondemand::json_type::string: { + auto json_value = json.at(pointer); + res = cast_function.cast( + std::string(json_value.value())); + break; + } + + case simdjson::ondemand::json_type::number: { + if (json.get_number_type(pointer) == + simdjson::ondemand::number_type::floating_point_number) { + auto json_value = json.at(pointer); + res = cast_function.cast(json_value.value()); + } else { + auto json_value = json.at(pointer); + res = cast_function.cast(json_value.value()); + } + break; + } + + case simdjson::ondemand::json_type::boolean: { + auto json_value = json.at(pointer); + res = cast_function.cast(json_value.value()); + break; + } + + default: + break; + } + + return res; +} + +template std::optional +JsonCastFunction::CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + +template std::optional +JsonCastFunction::CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + +template std::optional +JsonCastFunction::CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + +template std::optional +JsonCastFunction::CastJsonValue( + const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + +} // namespace milvus diff --git a/internal/core/src/common/JsonCastFunction.h b/internal/core/src/common/JsonCastFunction.h new file mode 100644 index 0000000000..fad2facbae --- /dev/null +++ b/internal/core/src/common/JsonCastFunction.h @@ -0,0 +1,79 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include "common/EasyAssert.h" +#include "common/Json.h" +#include "common/JsonCastType.h" + +namespace milvus { + +class JsonCastFunction { + public: + enum class Type { + kUnknown, + kString2Double, + }; + + template + std::optional + cast(const F& t) const { + PanicInfo(Unsupported, "Not implemented"); + } + + template + bool + match() const { + switch (cast_function_type_) { + case Type::kString2Double: + return std::is_same_v; + } + return false; + } + + static JsonCastFunction + FromString(const std::string& str); + + template + static std::optional + CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + + private: + JsonCastFunction(Type type) : cast_function_type_(type) { + } + Type cast_function_type_; + + static const std::unordered_map + predefined_cast_functions_; +}; +template <> +std::optional +JsonCastFunction::cast(const std::string& t) const; + +template <> +std::optional +JsonCastFunction::cast(const int64_t& t) const; + +template <> +std::optional +JsonCastFunction::cast(const double& t) const; + +template <> +std::optional +JsonCastFunction::cast(const bool& t) const; + +} // namespace milvus diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index fdfe33dce6..bfdfa34a9e 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -380,20 +380,30 @@ IndexFactory::CreateJsonIndex( IndexType index_type, JsonCastType cast_dtype, const std::string& nested_path, - const storage::FileManagerContext& file_manager_context) { + const storage::FileManagerContext& file_manager_context, + const std::string& json_cast_function) { AssertInfo(index_type == INVERTED_INDEX_TYPE, "Invalid index type for json index"); switch (cast_dtype.element_type()) { case JsonCastType::DataType::BOOL: return std::make_unique>( - cast_dtype, nested_path, file_manager_context); + cast_dtype, + nested_path, + file_manager_context, + JsonCastFunction::FromString(json_cast_function)); case JsonCastType::DataType::DOUBLE: return std::make_unique>( - cast_dtype, nested_path, file_manager_context); + cast_dtype, + nested_path, + file_manager_context, + JsonCastFunction::FromString(json_cast_function)); case JsonCastType::DataType::VARCHAR: return std::make_unique>( - cast_dtype, nested_path, file_manager_context); + cast_dtype, + nested_path, + file_manager_context, + JsonCastFunction::FromString(json_cast_function)); default: PanicInfo(DataTypeInvalid, "Invalid data type:{}", cast_dtype); } @@ -424,7 +434,8 @@ IndexFactory::CreateScalarIndex( return CreateJsonIndex(create_index_info.index_type, create_index_info.json_cast_type, create_index_info.json_path, - file_manager_context); + file_manager_context, + create_index_info.json_cast_function); } default: PanicInfo(DataTypeInvalid, "Invalid data type:{}", data_type); diff --git a/internal/core/src/index/IndexFactory.h b/internal/core/src/index/IndexFactory.h index 27bd62883e..42cf952fc6 100644 --- a/internal/core/src/index/IndexFactory.h +++ b/internal/core/src/index/IndexFactory.h @@ -108,11 +108,13 @@ class IndexFactory { storage::FileManagerContext()); IndexBasePtr - CreateJsonIndex(IndexType index_type, - JsonCastType cast_dtype, - const std::string& nested_path, - const storage::FileManagerContext& file_manager_context = - storage::FileManagerContext()); + CreateJsonIndex( + IndexType index_type, + JsonCastType cast_dtype, + const std::string& nested_path, + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext(), + const std::string& json_cast_function = UNKNOW_CAST_FUNCTION_NAME); IndexBasePtr CreateScalarIndex(const CreateIndexInfo& create_index_info, diff --git a/internal/core/src/index/IndexInfo.h b/internal/core/src/index/IndexInfo.h index df957846ce..2b4b6fca63 100644 --- a/internal/core/src/index/IndexInfo.h +++ b/internal/core/src/index/IndexInfo.h @@ -30,6 +30,7 @@ struct CreateIndexInfo { int32_t scalar_index_engine_version; JsonCastType json_cast_type{JsonCastType::UNKNOWN}; std::string json_path; + std::string json_cast_function; }; } // namespace milvus::index diff --git a/internal/core/src/index/JsonInvertedIndex.cpp b/internal/core/src/index/JsonInvertedIndex.cpp index 569819b443..50b1868bd6 100644 --- a/internal/core/src/index/JsonInvertedIndex.cpp +++ b/internal/core/src/index/JsonInvertedIndex.cpp @@ -145,13 +145,21 @@ JsonInvertedIndex::build_index_for_json( } } } else { - value_result res = - json_column->at(nested_path_); - if (res.error() != simdjson::SUCCESS) { - error_recorder_.Record( - *json_column, nested_path_, res.error()); + if (cast_function_.match()) { + auto res = JsonCastFunction::CastJsonValue( + cast_function_, *json_column, nested_path_); + if (res.has_value()) { + values.push_back(res.value()); + } } else { - values.push_back(static_cast(res.value())); + value_result res = + json_column->at(nested_path_); + if (res.error() != simdjson::SUCCESS) { + error_recorder_.Record( + *json_column, nested_path_, res.error()); + } else { + values.push_back(static_cast(res.value())); + } } } this->wrapper_->template add_multi_data( diff --git a/internal/core/src/index/JsonInvertedIndex.h b/internal/core/src/index/JsonInvertedIndex.h index bde76a51f9..f721bae87e 100644 --- a/internal/core/src/index/JsonInvertedIndex.h +++ b/internal/core/src/index/JsonInvertedIndex.h @@ -12,6 +12,7 @@ #pragma once #include #include "common/FieldDataInterface.h" +#include "common/JsonCastFunction.h" #include "common/JsonCastType.h" #include "index/InvertedIndexTantivy.h" #include "index/ScalarIndex.h" @@ -69,8 +70,12 @@ class JsonInvertedIndex : public index::InvertedIndexTantivy { public: JsonInvertedIndex(const JsonCastType& cast_type, const std::string& nested_path, - const storage::FileManagerContext& ctx) - : nested_path_(nested_path), cast_type_(cast_type) { + const storage::FileManagerContext& ctx, + const JsonCastFunction& cast_function = + JsonCastFunction::FromString("unknown")) + : nested_path_(nested_path), + cast_type_(cast_type), + cast_function_(cast_function) { this->schema_ = ctx.fieldDataMeta.field_schema; this->mem_file_manager_ = std::make_shared(ctx); @@ -124,6 +129,7 @@ class JsonInvertedIndex : public index::InvertedIndexTantivy { std::string nested_path_; JsonInvertedIndexParseErrorRecorder error_recorder_; JsonCastType cast_type_; + JsonCastFunction cast_function_; }; } // namespace milvus::index diff --git a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp index a7934f4d73..6f19eba85b 100644 --- a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp +++ b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp @@ -47,6 +47,10 @@ ScalarIndexCreator::ScalarIndexCreator( index_info.json_cast_type = milvus::JsonCastType::FromString( config.at(JSON_CAST_TYPE).get()); index_info.json_path = config.at(JSON_PATH).get(); + if (config.contains(JSON_CAST_FUNCTION)) { + index_info.json_cast_function = + config.at(JSON_CAST_FUNCTION).get(); + } } index_ = index::IndexFactory::GetInstance().CreateIndex( index_info, file_manager_context); diff --git a/internal/core/unittest/test_json_index.cpp b/internal/core/unittest/test_json_index.cpp index 36ddc6aee3..3fa5bcb20d 100644 --- a/internal/core/unittest/test_json_index.cpp +++ b/internal/core/unittest/test_json_index.cpp @@ -185,3 +185,84 @@ TEST(JsonIndexTest, TestJsonContains) { } } } + +TEST(JsonIndexTest, TestJsonCast) { + std::vector json_raw_data = { + R"(1)", + R"({"a": 1.0})", + R"({"a": 1})", + R"({"a": "1.0"})", + R"({"a": true})", + R"({"a": [1, 2, 3]})", + R"({"a": {"b": 1}})", + }; + + auto json_path = "/a"; + auto schema = std::make_shared(); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + + auto file_manager_ctx = storage::FileManagerContext(); + file_manager_ctx.fieldDataMeta.field_schema.set_data_type( + milvus::proto::schema::JSON); + file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get()); + file_manager_ctx.fieldDataMeta.field_id = json_fid.get(); + + auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex( + index::INVERTED_INDEX_TYPE, + JsonCastType::FromString("DOUBLE"), + json_path, + file_manager_ctx, + "STRING_TO_DOUBLE"); + auto json_index = std::unique_ptr>( + static_cast*>(inv_index.release())); + + std::vector jsons; + for (auto& json : json_raw_data) { + jsons.push_back(milvus::Json(simdjson::padded_string(json))); + } + + auto json_field = + std::make_shared>(DataType::JSON, false); + json_field->add_json_data(jsons); + json_index->BuildWithFieldData({json_field}); + json_index->finish(); + json_index->create_reader(); + + auto segment = segcore::CreateSealedSegment(schema); + segcore::LoadIndexInfo load_index_info; + load_index_info.field_id = json_fid.get(); + load_index_info.field_type = DataType::JSON; + load_index_info.index = std::move(json_index); + load_index_info.index_params = {{JSON_PATH, json_path}}; + segment->LoadIndex(load_index_info); + + auto field_data_info = FieldDataInfo{json_fid.get(), + json_raw_data.size(), + std::vector{json_field}}; + segment->LoadFieldData(json_fid, field_data_info); + + std::vector>> + test_cases; + + proto::plan::GenericValue value; + value.set_int64_val(1); + test_cases.push_back(std::make_tuple(value, std::vector{1, 2, 3})); + for (auto& test_case : test_cases) { + auto expr = std::make_shared( + expr::ColumnInfo(json_fid, DataType::JSON, {"a"}, true), + proto::plan::OpType::Equal, + value); + + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + + auto result = query::ExecuteQueryExpr( + plan, segment.get(), json_raw_data.size(), MAX_TIMESTAMP); + + auto expect_result = std::get<1>(test_case); + EXPECT_EQ(result.count(), expect_result.size()); + for (auto& id : expect_result) { + EXPECT_TRUE(result[id]); + } + } +} diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index c6c21cc890..398f5ce360 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -196,6 +196,9 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error { if jsonCastType, exist := indexParamsMap[common.JSONCastTypeKey]; exist { indexParamsMap[common.JSONCastTypeKey] = strings.ToUpper(strings.TrimSpace(jsonCastType)) } + if jsonCastFunction, exist := indexParamsMap[common.JSONCastFunctionKey]; exist { + indexParamsMap[common.JSONCastFunctionKey] = strings.ToUpper(strings.TrimSpace(jsonCastFunction)) + } if err := ValidateAutoIndexMmapConfig(isVecIndex, indexParamsMap); err != nil { return err diff --git a/internal/util/indexparamcheck/inverted_checker.go b/internal/util/indexparamcheck/inverted_checker.go index a83cb9bb1d..9b1acd7776 100644 --- a/internal/util/indexparamcheck/inverted_checker.go +++ b/internal/util/indexparamcheck/inverted_checker.go @@ -18,6 +18,8 @@ type INVERTEDChecker struct { var validJSONCastTypes = []string{"BOOL", "DOUBLE", "VARCHAR"} +var validJSONCastFunctions = []string{"STRING_TO_DOUBLE"} + func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { // check json index params isJSONIndex := typeutil.IsJSONType(dataType) @@ -30,6 +32,17 @@ func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[stri if !lo.Contains(validJSONCastTypes, castType) { return merr.WrapErrParameterInvalidMsg("json_cast_type %v is not supported", castType) } + castFunction, exist := params[common.JSONCastFunctionKey] + if exist { + switch castFunction { + case "STRING_TO_DOUBLE": + if castType != "DOUBLE" { + return merr.WrapErrParameterInvalidMsg("json_cast_function %v is not supported for json_cast_type %v", castFunction, castType) + } + default: + return merr.WrapErrParameterInvalidMsg("json_cast_function %v is not supported", castFunction) + } + } } return c.scalarIndexChecker.CheckTrain(dataType, params) } diff --git a/pkg/common/common.go b/pkg/common/common.go index 73db81416b..9a4b536f6c 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -149,8 +149,9 @@ const ( ConsistencyLevel = "consistency_level" HintsKey = "hints" - JSONCastTypeKey = "json_cast_type" - JSONPathKey = "json_path" + JSONCastTypeKey = "json_cast_type" + JSONPathKey = "json_path" + JSONCastFunctionKey = "json_cast_function" ) // Doc-in-doc-out