diff --git a/internal/core/src/common/Consts.h b/internal/core/src/common/Consts.h index 8e4d96c9e2..a138f85b09 100644 --- a/internal/core/src/common/Consts.h +++ b/internal/core/src/common/Consts.h @@ -81,6 +81,7 @@ const size_t MARISA_NULL_KEY_ID = -1; const std::string JSON_CAST_TYPE = "json_cast_type"; const std::string JSON_PATH = "json_path"; +const std::string JSON_CAST_FUNCTION = "json_cast_function"; const bool DEFAULT_OPTIMIZE_EXPR_ENABLED = true; const int64_t DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT = 150; const int64_t DEFAULT_JSON_INDEX_MEMORY_BUDGET = 16777216; // bytes, 16MB @@ -99,4 +100,6 @@ const std::string INDEX_NUM_ROWS_KEY = "index_num_rows"; // storage version const int64_t STORAGE_V1 = 1; -const int64_t STORAGE_V2 = 2; \ No newline at end of file +const int64_t STORAGE_V2 = 2; + +const std::string UNKNOW_CAST_FUNCTION_NAME = "unknown"; \ No newline at end of file diff --git a/internal/core/src/common/Json.h b/internal/core/src/common/Json.h index cdcc544194..bfb147a527 100644 --- a/internal/core/src/common/Json.h +++ b/internal/core/src/common/Json.h @@ -221,6 +221,17 @@ class Json { return pointer; } + auto + type(const std::string& pointer) const { + return pointer.empty() ? doc().type() + : doc().at_pointer(pointer).type(); + } + + auto get_number_type(const std::string& pointer) const { + return pointer.empty() ? doc().get_number_type() + : doc().at_pointer(pointer).get_number_type(); + } + template value_result at(std::string_view pointer) const { diff --git a/internal/core/src/common/JsonCastFunction.cpp b/internal/core/src/common/JsonCastFunction.cpp new file mode 100644 index 0000000000..69e2abbd98 --- /dev/null +++ b/internal/core/src/common/JsonCastFunction.cpp @@ -0,0 +1,113 @@ +#include "common/JsonCastFunction.h" +#include +#include "common/EasyAssert.h" + +namespace milvus { + +const std::unordered_map + JsonCastFunction::predefined_cast_functions_ = { + {"STRING_TO_DOUBLE", JsonCastFunction(Type::kString2Double)}, +}; + +JsonCastFunction +JsonCastFunction::FromString(const std::string& str) { + auto it = predefined_cast_functions_.find(str); + if (it != predefined_cast_functions_.end()) { + return it->second; + } + return JsonCastFunction(Type::kUnknown); +} + +template <> +std::optional +JsonCastFunction::cast(const std::string& t) const { + try { + return std::stod(t); + } catch (const std::exception&) { + return std::nullopt; + } +} + +template <> +std::optional +JsonCastFunction::cast(const int64_t& t) const { + return static_cast(t); +} + +template <> +std::optional +JsonCastFunction::cast(const double& t) const { + return t; +} + +template <> +std::optional +JsonCastFunction::cast(const bool& t) const { + return std::nullopt; +} + +template +std::optional +JsonCastFunction::CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer) { + AssertInfo(cast_function.match(), "Type mismatch"); + + auto json_type = json.type(pointer); + std::optional res; + + switch (json_type.value()) { + case simdjson::ondemand::json_type::string: { + auto json_value = json.at(pointer); + res = cast_function.cast( + std::string(json_value.value())); + break; + } + + case simdjson::ondemand::json_type::number: { + if (json.get_number_type(pointer) == + simdjson::ondemand::number_type::floating_point_number) { + auto json_value = json.at(pointer); + res = cast_function.cast(json_value.value()); + } else { + auto json_value = json.at(pointer); + res = cast_function.cast(json_value.value()); + } + break; + } + + case simdjson::ondemand::json_type::boolean: { + auto json_value = json.at(pointer); + res = cast_function.cast(json_value.value()); + break; + } + + default: + break; + } + + return res; +} + +template std::optional +JsonCastFunction::CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + +template std::optional +JsonCastFunction::CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + +template std::optional +JsonCastFunction::CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + +template std::optional +JsonCastFunction::CastJsonValue( + const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + +} // namespace milvus diff --git a/internal/core/src/common/JsonCastFunction.h b/internal/core/src/common/JsonCastFunction.h new file mode 100644 index 0000000000..61c8e4a196 --- /dev/null +++ b/internal/core/src/common/JsonCastFunction.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include +#include "common/EasyAssert.h" +#include "common/Json.h" +#include "common/JsonCastType.h" + +namespace milvus { + +class JsonCastFunction { + public: + enum class Type { + kUnknown, + kString2Double, + }; + + template + std::optional + cast(const F& t) const { + PanicInfo(Unsupported, "Not implemented"); + } + + template + bool + match() const { + switch (cast_function_type_) { + case Type::kString2Double: + return std::is_same_v; + } + return false; + } + + static JsonCastFunction + FromString(const std::string& str); + + template + static std::optional + CastJsonValue(const JsonCastFunction& cast_function, + const Json& json, + const std::string& pointer); + + private: + JsonCastFunction(Type type) : cast_function_type_(type) { + } + Type cast_function_type_; + + static const std::unordered_map + predefined_cast_functions_; +}; +template <> +std::optional +JsonCastFunction::cast(const std::string& t) const; + +template <> +std::optional +JsonCastFunction::cast(const int64_t& t) const; + +template <> +std::optional +JsonCastFunction::cast(const double& t) const; + +template <> +std::optional +JsonCastFunction::cast(const bool& t) const; + +} // namespace milvus diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index 75b3a8eff9..7ce06bc9ee 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -384,20 +384,30 @@ IndexFactory::CreateJsonIndex( IndexType index_type, JsonCastType cast_dtype, const std::string& nested_path, - const storage::FileManagerContext& file_manager_context) { + const storage::FileManagerContext& file_manager_context, + const std::string& json_cast_function) { AssertInfo(index_type == INVERTED_INDEX_TYPE, "Invalid index type for json index"); switch (cast_dtype.element_type()) { case JsonCastType::DataType::BOOL: return std::make_unique>( - cast_dtype, nested_path, file_manager_context); + cast_dtype, + nested_path, + file_manager_context, + JsonCastFunction::FromString(json_cast_function)); case JsonCastType::DataType::DOUBLE: return std::make_unique>( - cast_dtype, nested_path, file_manager_context); + cast_dtype, + nested_path, + file_manager_context, + JsonCastFunction::FromString(json_cast_function)); case JsonCastType::DataType::VARCHAR: return std::make_unique>( - cast_dtype, nested_path, file_manager_context); + cast_dtype, + nested_path, + file_manager_context, + JsonCastFunction::FromString(json_cast_function)); default: PanicInfo(DataTypeInvalid, "Invalid data type:{}", cast_dtype); } @@ -428,7 +438,8 @@ IndexFactory::CreateScalarIndex( return CreateJsonIndex(create_index_info.index_type, create_index_info.json_cast_type, create_index_info.json_path, - file_manager_context); + file_manager_context, + create_index_info.json_cast_function); } default: PanicInfo(DataTypeInvalid, "Invalid data type:{}", data_type); diff --git a/internal/core/src/index/IndexFactory.h b/internal/core/src/index/IndexFactory.h index 693ad0d35a..9d3058efb1 100644 --- a/internal/core/src/index/IndexFactory.h +++ b/internal/core/src/index/IndexFactory.h @@ -113,7 +113,9 @@ class IndexFactory { JsonCastType cast_dtype, const std::string& nested_path, const storage::FileManagerContext& file_manager_context = - storage::FileManagerContext()); + storage::FileManagerContext(), + const std::string& json_cast_function = + UNKNOW_CAST_FUNCTION_NAME); IndexBasePtr CreateScalarIndex(const CreateIndexInfo& create_index_info, diff --git a/internal/core/src/index/IndexInfo.h b/internal/core/src/index/IndexInfo.h index 486bd2f0f1..cf0ceae96d 100644 --- a/internal/core/src/index/IndexInfo.h +++ b/internal/core/src/index/IndexInfo.h @@ -31,6 +31,7 @@ struct CreateIndexInfo { uint32_t tantivy_index_version{7}; JsonCastType json_cast_type{JsonCastType::UNKNOWN}; std::string json_path; + std::string json_cast_function; }; } // namespace milvus::index diff --git a/internal/core/src/index/JsonInvertedIndex.cpp b/internal/core/src/index/JsonInvertedIndex.cpp index 76b3206191..db1ba17fdc 100644 --- a/internal/core/src/index/JsonInvertedIndex.cpp +++ b/internal/core/src/index/JsonInvertedIndex.cpp @@ -148,13 +148,21 @@ JsonInvertedIndex::build_index_for_json( } } } else { - value_result res = - json_column->at(nested_path_); - if (res.error() != simdjson::SUCCESS) { - error_recorder_.Record( - *json_column, nested_path_, res.error()); + if (cast_function_.match()) { + auto res = JsonCastFunction::CastJsonValue( + cast_function_, *json_column, nested_path_); + if (res.has_value()) { + values.push_back(res.value()); + } } else { - values.push_back(static_cast(res.value())); + value_result res = + json_column->at(nested_path_); + if (res.error() != simdjson::SUCCESS) { + error_recorder_.Record( + *json_column, nested_path_, res.error()); + } else { + values.push_back(static_cast(res.value())); + } } } this->wrapper_->template add_array_data( diff --git a/internal/core/src/index/JsonInvertedIndex.h b/internal/core/src/index/JsonInvertedIndex.h index 598e215589..7e6df24952 100644 --- a/internal/core/src/index/JsonInvertedIndex.h +++ b/internal/core/src/index/JsonInvertedIndex.h @@ -12,6 +12,7 @@ #pragma once #include #include "common/FieldDataInterface.h" +#include "common/JsonCastFunction.h" #include "common/JsonCastType.h" #include "index/InvertedIndexTantivy.h" #include "index/ScalarIndex.h" @@ -69,8 +70,12 @@ class JsonInvertedIndex : public index::InvertedIndexTantivy { public: JsonInvertedIndex(const JsonCastType& cast_type, const std::string& nested_path, - const storage::FileManagerContext& ctx) - : nested_path_(nested_path), cast_type_(cast_type) { + const storage::FileManagerContext& ctx, + const JsonCastFunction& cast_function = + JsonCastFunction::FromString("unknown")) + : nested_path_(nested_path), + cast_type_(cast_type), + cast_function_(cast_function) { this->schema_ = ctx.fieldDataMeta.field_schema; this->mem_file_manager_ = std::make_shared(ctx); @@ -127,6 +132,7 @@ class JsonInvertedIndex : public index::InvertedIndexTantivy { std::string nested_path_; JsonInvertedIndexParseErrorRecorder error_recorder_; JsonCastType cast_type_; + JsonCastFunction cast_function_; }; } // namespace milvus::index diff --git a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp index 789e3443d1..1c44b6aa0d 100644 --- a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp +++ b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp @@ -52,6 +52,10 @@ ScalarIndexCreator::ScalarIndexCreator( index_info.json_cast_type = milvus::JsonCastType::FromString( config.at(JSON_CAST_TYPE).get()); index_info.json_path = config.at(JSON_PATH).get(); + if (config.contains(JSON_CAST_FUNCTION)) { + index_info.json_cast_function = + config.at(JSON_CAST_FUNCTION).get(); + } } index_ = index::IndexFactory::GetInstance().CreateIndex( index_info, file_manager_context); diff --git a/internal/core/unittest/test_json_index.cpp b/internal/core/unittest/test_json_index.cpp index ccca221522..b242ed8237 100644 --- a/internal/core/unittest/test_json_index.cpp +++ b/internal/core/unittest/test_json_index.cpp @@ -186,3 +186,86 @@ TEST(JsonIndexTest, TestJsonContains) { } } } + +TEST(JsonIndexTest, TestJsonCast) { + std::vector json_raw_data = { + R"(1)", + R"({"a": 1.0})", + R"({"a": 1})", + R"({"a": "1.0"})", + R"({"a": true})", + R"({"a": [1, 2, 3]})", + R"({"a": {"b": 1}})", + }; + + auto json_path = "/a"; + auto schema = std::make_shared(); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + + auto file_manager_ctx = storage::FileManagerContext(); + file_manager_ctx.fieldDataMeta.field_schema.set_data_type( + milvus::proto::schema::JSON); + file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get()); + file_manager_ctx.fieldDataMeta.field_id = json_fid.get(); + + auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex( + index::INVERTED_INDEX_TYPE, + JsonCastType::FromString("DOUBLE"), + json_path, + file_manager_ctx, + "STRING_TO_DOUBLE"); + auto json_index = std::unique_ptr>( + static_cast*>(inv_index.release())); + + std::vector jsons; + for (auto& json : json_raw_data) { + jsons.push_back(milvus::Json(simdjson::padded_string(json))); + } + + auto json_field = + std::make_shared>(DataType::JSON, false); + json_field->add_json_data(jsons); + json_index->BuildWithFieldData({json_field}); + json_index->finish(); + json_index->create_reader(); + + auto segment = segcore::CreateSealedSegment(schema); + segcore::LoadIndexInfo load_index_info; + load_index_info.field_id = json_fid.get(); + load_index_info.field_type = DataType::JSON; + load_index_info.index = std::move(json_index); + load_index_info.index_params = {{JSON_PATH, json_path}}; + segment->LoadIndex(load_index_info); + + auto cm = milvus::storage::RemoteChunkManagerSingleton::GetInstance() + .GetRemoteChunkManager(); + auto load_info = PrepareSingleFieldInsertBinlog( + 0, 0, 0, json_fid.get(), {json_field}, cm); + segment->LoadFieldData(load_info); + + std::vector>> + test_cases; + + proto::plan::GenericValue value; + value.set_int64_val(1); + test_cases.push_back(std::make_tuple(value, std::vector{1, 2, 3})); + for (auto& test_case : test_cases) { + auto expr = std::make_shared( + expr::ColumnInfo(json_fid, DataType::JSON, {"a"}, true), + proto::plan::OpType::Equal, + value, + std::vector{}); + + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, expr); + + auto result = query::ExecuteQueryExpr( + plan, segment.get(), json_raw_data.size(), MAX_TIMESTAMP); + + auto expect_result = std::get<1>(test_case); + EXPECT_EQ(result.count(), expect_result.size()); + for (auto& id : expect_result) { + EXPECT_TRUE(result[id]); + } + } +} diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index bd32bbee57..454ffe164e 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -195,6 +195,9 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error { if jsonCastType, exist := indexParamsMap[common.JSONCastTypeKey]; exist { indexParamsMap[common.JSONCastTypeKey] = strings.ToUpper(strings.TrimSpace(jsonCastType)) } + if jsonCastFunction, exist := indexParamsMap[common.JSONCastFunctionKey]; exist { + indexParamsMap[common.JSONCastFunctionKey] = strings.ToUpper(strings.TrimSpace(jsonCastFunction)) + } if err := ValidateAutoIndexMmapConfig(isVecIndex, indexParamsMap); err != nil { return err diff --git a/internal/util/indexparamcheck/inverted_checker.go b/internal/util/indexparamcheck/inverted_checker.go index a83cb9bb1d..9b1acd7776 100644 --- a/internal/util/indexparamcheck/inverted_checker.go +++ b/internal/util/indexparamcheck/inverted_checker.go @@ -18,6 +18,8 @@ type INVERTEDChecker struct { var validJSONCastTypes = []string{"BOOL", "DOUBLE", "VARCHAR"} +var validJSONCastFunctions = []string{"STRING_TO_DOUBLE"} + func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { // check json index params isJSONIndex := typeutil.IsJSONType(dataType) @@ -30,6 +32,17 @@ func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[stri if !lo.Contains(validJSONCastTypes, castType) { return merr.WrapErrParameterInvalidMsg("json_cast_type %v is not supported", castType) } + castFunction, exist := params[common.JSONCastFunctionKey] + if exist { + switch castFunction { + case "STRING_TO_DOUBLE": + if castType != "DOUBLE" { + return merr.WrapErrParameterInvalidMsg("json_cast_function %v is not supported for json_cast_type %v", castFunction, castType) + } + default: + return merr.WrapErrParameterInvalidMsg("json_cast_function %v is not supported", castFunction) + } + } } return c.scalarIndexChecker.CheckTrain(dataType, params) } diff --git a/pkg/common/common.go b/pkg/common/common.go index c83d7b096b..dfe5e99e6f 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -151,8 +151,9 @@ const ( ConsistencyLevel = "consistency_level" HintsKey = "hints" - JSONCastTypeKey = "json_cast_type" - JSONPathKey = "json_path" + JSONCastTypeKey = "json_cast_type" + JSONPathKey = "json_path" + JSONCastFunctionKey = "json_cast_function" ) // Doc-in-doc-out