enhance: Support cast function for json index (#41949)

issue: #41948

---------

Signed-off-by: sunby <sunbingyi1992@gmail.com>
This commit is contained in:
Bingyi Sun 2025-06-05 19:42:32 +08:00 committed by GitHub
parent e3826c29ce
commit cc5ac1c220
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 344 additions and 17 deletions

View File

@ -81,6 +81,7 @@ const size_t MARISA_NULL_KEY_ID = -1;
// Config keys for JSON indexes. ScalarIndexCreator reads JSON_CAST_TYPE and
// JSON_PATH with config.at(); JSON_CAST_FUNCTION is optional and is only read
// when config.contains() reports it.
const std::string JSON_CAST_TYPE = "json_cast_type";
const std::string JSON_PATH = "json_path";
const std::string JSON_CAST_FUNCTION = "json_cast_function";
const bool DEFAULT_OPTIMIZE_EXPR_ENABLED = true;
const int64_t DEFAULT_CONVERT_OR_TO_IN_NUMERIC_LIMIT = 150;
const int64_t DEFAULT_JSON_INDEX_MEMORY_BUDGET = 16777216; // bytes, 16MB
@ -99,4 +100,6 @@ const std::string INDEX_NUM_ROWS_KEY = "index_num_rows";
// storage version
const int64_t STORAGE_V1 = 1;
const int64_t STORAGE_V2 = 2;
const int64_t STORAGE_V2 = 2;
// Default JSON cast-function name used by IndexFactory::CreateJsonIndex when
// the caller supplies none. It is not a registered cast function, so
// JsonCastFunction::FromString resolves it to Type::kUnknown.
// NOTE(review): the identifier is spelled "UNKNOW" (sic); other code refers to
// it by this name, so the spelling is kept.
const std::string UNKNOW_CAST_FUNCTION_NAME = "unknown";

View File

@ -221,6 +221,17 @@ class Json {
return pointer;
}
// Returns the simdjson type result for the value addressed by `pointer`.
// An empty pointer addresses the document root.
auto
type(const std::string& pointer) const {
    if (pointer.empty()) {
        return doc().type();
    }
    return doc().at_pointer(pointer).type();
}
// Returns the simdjson number-type result for the value addressed by
// `pointer`. An empty pointer addresses the document root.
auto
get_number_type(const std::string& pointer) const {
    if (pointer.empty()) {
        return doc().get_number_type();
    }
    return doc().at_pointer(pointer).get_number_type();
}
template <typename T>
value_result<T>
at(std::string_view pointer) const {

View File

@ -0,0 +1,113 @@
#include "common/JsonCastFunction.h"
#include <string>
#include "common/EasyAssert.h"
namespace milvus {
const std::unordered_map<std::string, JsonCastFunction>
JsonCastFunction::predefined_cast_functions_ = {
{"STRING_TO_DOUBLE", JsonCastFunction(Type::kString2Double)},
};
// Looks up a cast function by its registered name. Names that are not in
// predefined_cast_functions_ yield a Type::kUnknown cast function.
JsonCastFunction
JsonCastFunction::FromString(const std::string& str) {
    const auto entry = predefined_cast_functions_.find(str);
    const bool is_registered = entry != predefined_cast_functions_.end();
    return is_registered ? entry->second : JsonCastFunction(Type::kUnknown);
}
// Parses a JSON string value as a double.
//
// Returns std::nullopt when `t` is not a number, is out of range for double,
// or contains anything after the number.
template <>
std::optional<double>
JsonCastFunction::cast<double, std::string>(const std::string& t) const {
    try {
        std::size_t consumed = 0;
        const double parsed = std::stod(t, &consumed);
        // std::stod stops at the first character that cannot extend the
        // number, so without this check "12abc" would be indexed as 12.
        if (consumed != t.size()) {
            return std::nullopt;
        }
        return parsed;
    } catch (const std::exception&) {
        // std::invalid_argument (no conversion) or std::out_of_range.
        return std::nullopt;
    }
}
// Integer input: converted to double with static_cast.
template <>
std::optional<double>
JsonCastFunction::cast<double, int64_t>(const int64_t& t) const {
    const auto widened = static_cast<double>(t);
    return std::optional<double>{widened};
}

// Double input: passed through unchanged.
template <>
std::optional<double>
JsonCastFunction::cast<double, double>(const double& t) const {
    return std::optional<double>{t};
}

// Boolean input: never produces a double.
template <>
std::optional<double>
JsonCastFunction::cast<double, bool>(const bool& /*t*/) const {
    return {};
}
template <typename T>
std::optional<T>
JsonCastFunction::CastJsonValue(const JsonCastFunction& cast_function,
const Json& json,
const std::string& pointer) {
AssertInfo(cast_function.match<T>(), "Type mismatch");
auto json_type = json.type(pointer);
std::optional<T> res;
switch (json_type.value()) {
case simdjson::ondemand::json_type::string: {
auto json_value = json.at<std::string_view>(pointer);
res = cast_function.cast<T, std::string>(
std::string(json_value.value()));
break;
}
case simdjson::ondemand::json_type::number: {
if (json.get_number_type(pointer) ==
simdjson::ondemand::number_type::floating_point_number) {
auto json_value = json.at<double>(pointer);
res = cast_function.cast<T, double>(json_value.value());
} else {
auto json_value = json.at<int64_t>(pointer);
res = cast_function.cast<T, int64_t>(json_value.value());
}
break;
}
case simdjson::ondemand::json_type::boolean: {
auto json_value = json.at<bool>(pointer);
res = cast_function.cast<T, bool>(json_value.value());
break;
}
default:
break;
}
return res;
}
// Explicit instantiations. CastJsonValue is only declared in the header and
// defined in this translation unit, so each element type that other
// translation units may request has to be instantiated here.
template std::optional<bool>
JsonCastFunction::CastJsonValue<bool>(const JsonCastFunction& cast_function,
const Json& json,
const std::string& pointer);
template std::optional<int64_t>
JsonCastFunction::CastJsonValue<int64_t>(const JsonCastFunction& cast_function,
const Json& json,
const std::string& pointer);
template std::optional<double>
JsonCastFunction::CastJsonValue<double>(const JsonCastFunction& cast_function,
const Json& json,
const std::string& pointer);
template std::optional<std::string>
JsonCastFunction::CastJsonValue<std::string>(
const JsonCastFunction& cast_function,
const Json& json,
const std::string& pointer);
} // namespace milvus

View File

@ -0,0 +1,68 @@
#pragma once
#include <functional>
#include <optional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include "common/EasyAssert.h"
#include "common/Json.h"
#include "common/JsonCastType.h"
namespace milvus {
// A named conversion applied to JSON values while a JSON index is built.
// Instances are obtained through FromString(); the constructor is private.
class JsonCastFunction {
 public:
    // Supported conversions. kUnknown stands for "no cast function".
    enum class Type {
        kUnknown,
        kString2Double,
    };

    // Converts `t` (of source type F) to T. Only the explicit specializations
    // declared after this class are implemented; every other <T, F>
    // combination reaches this primary template and panics.
    template <typename T, typename F>
    std::optional<T>
    cast(const F& t) const {
        PanicInfo(Unsupported, "Not implemented");
    }

    // Returns true when this cast function produces values of type T.
    // A kUnknown cast function matches no type.
    template <typename T>
    bool
    match() const {
        switch (cast_function_type_) {
            case Type::kString2Double:
                return std::is_same_v<T, double>;
            case Type::kUnknown:
                return false;
        }
        // Not reachable for valid enumerators; guards against out-of-range
        // values.
        return false;
    }

    // Resolves a cast function by name; unregistered names yield kUnknown.
    static JsonCastFunction
    FromString(const std::string& str);

    // Reads the value at `pointer` in `json` and converts it to T using
    // `cast_function`. Returns std::nullopt when no value is produced.
    template <typename T>
    static std::optional<T>
    CastJsonValue(const JsonCastFunction& cast_function,
                  const Json& json,
                  const std::string& pointer);

 private:
    explicit JsonCastFunction(Type type) : cast_function_type_(type) {
    }

    Type cast_function_type_;

    // Name -> cast function registry consulted by FromString().
    static const std::unordered_map<std::string, JsonCastFunction>
        predefined_cast_functions_;
};
// Declarations of the cast<> specializations defined in JsonCastFunction.cpp.
// All of them target double; any other <T, F> combination uses the primary
// template, which panics.

// String input: parsed with std::stod; std::nullopt when parsing fails.
template <>
std::optional<double>
JsonCastFunction::cast<double, std::string>(const std::string& t) const;
// Integer input: converted with static_cast<double>.
template <>
std::optional<double>
JsonCastFunction::cast<double, int64_t>(const int64_t& t) const;
// Double input: returned unchanged.
template <>
std::optional<double>
JsonCastFunction::cast<double, double>(const double& t) const;
// Boolean input: always std::nullopt.
template <>
std::optional<double>
JsonCastFunction::cast<double, bool>(const bool& t) const;
} // namespace milvus

View File

@ -384,20 +384,30 @@ IndexFactory::CreateJsonIndex(
IndexType index_type,
JsonCastType cast_dtype,
const std::string& nested_path,
const storage::FileManagerContext& file_manager_context) {
const storage::FileManagerContext& file_manager_context,
const std::string& json_cast_function) {
AssertInfo(index_type == INVERTED_INDEX_TYPE,
"Invalid index type for json index");
switch (cast_dtype.element_type()) {
case JsonCastType::DataType::BOOL:
return std::make_unique<index::JsonInvertedIndex<bool>>(
cast_dtype, nested_path, file_manager_context);
cast_dtype,
nested_path,
file_manager_context,
JsonCastFunction::FromString(json_cast_function));
case JsonCastType::DataType::DOUBLE:
return std::make_unique<index::JsonInvertedIndex<double>>(
cast_dtype, nested_path, file_manager_context);
cast_dtype,
nested_path,
file_manager_context,
JsonCastFunction::FromString(json_cast_function));
case JsonCastType::DataType::VARCHAR:
return std::make_unique<index::JsonInvertedIndex<std::string>>(
cast_dtype, nested_path, file_manager_context);
cast_dtype,
nested_path,
file_manager_context,
JsonCastFunction::FromString(json_cast_function));
default:
PanicInfo(DataTypeInvalid, "Invalid data type:{}", cast_dtype);
}
@ -428,7 +438,8 @@ IndexFactory::CreateScalarIndex(
return CreateJsonIndex(create_index_info.index_type,
create_index_info.json_cast_type,
create_index_info.json_path,
file_manager_context);
file_manager_context,
create_index_info.json_cast_function);
}
default:
PanicInfo(DataTypeInvalid, "Invalid data type:{}", data_type);

View File

@ -113,7 +113,9 @@ class IndexFactory {
JsonCastType cast_dtype,
const std::string& nested_path,
const storage::FileManagerContext& file_manager_context =
storage::FileManagerContext());
storage::FileManagerContext(),
const std::string& json_cast_function =
UNKNOW_CAST_FUNCTION_NAME);
IndexBasePtr
CreateScalarIndex(const CreateIndexInfo& create_index_info,

View File

@ -31,6 +31,7 @@ struct CreateIndexInfo {
uint32_t tantivy_index_version{7};
JsonCastType json_cast_type{JsonCastType::UNKNOWN};
std::string json_path;
std::string json_cast_function;
};
} // namespace milvus::index

View File

@ -148,13 +148,21 @@ JsonInvertedIndex<T>::build_index_for_json(
}
}
} else {
value_result<SIMDJSON_T> res =
json_column->at<SIMDJSON_T>(nested_path_);
if (res.error() != simdjson::SUCCESS) {
error_recorder_.Record(
*json_column, nested_path_, res.error());
if (cast_function_.match<T>()) {
auto res = JsonCastFunction::CastJsonValue<T>(
cast_function_, *json_column, nested_path_);
if (res.has_value()) {
values.push_back(res.value());
}
} else {
values.push_back(static_cast<T>(res.value()));
value_result<SIMDJSON_T> res =
json_column->at<SIMDJSON_T>(nested_path_);
if (res.error() != simdjson::SUCCESS) {
error_recorder_.Record(
*json_column, nested_path_, res.error());
} else {
values.push_back(static_cast<T>(res.value()));
}
}
}
this->wrapper_->template add_array_data<T>(

View File

@ -12,6 +12,7 @@
#pragma once
#include <cstdint>
#include "common/FieldDataInterface.h"
#include "common/JsonCastFunction.h"
#include "common/JsonCastType.h"
#include "index/InvertedIndexTantivy.h"
#include "index/ScalarIndex.h"
@ -69,8 +70,12 @@ class JsonInvertedIndex : public index::InvertedIndexTantivy<T> {
public:
JsonInvertedIndex(const JsonCastType& cast_type,
const std::string& nested_path,
const storage::FileManagerContext& ctx)
: nested_path_(nested_path), cast_type_(cast_type) {
const storage::FileManagerContext& ctx,
const JsonCastFunction& cast_function =
JsonCastFunction::FromString("unknown"))
: nested_path_(nested_path),
cast_type_(cast_type),
cast_function_(cast_function) {
this->schema_ = ctx.fieldDataMeta.field_schema;
this->mem_file_manager_ =
std::make_shared<storage::MemFileManagerImpl>(ctx);
@ -127,6 +132,7 @@ class JsonInvertedIndex : public index::InvertedIndexTantivy<T> {
std::string nested_path_;
JsonInvertedIndexParseErrorRecorder error_recorder_;
JsonCastType cast_type_;
JsonCastFunction cast_function_;
};
} // namespace milvus::index

View File

@ -52,6 +52,10 @@ ScalarIndexCreator::ScalarIndexCreator(
index_info.json_cast_type = milvus::JsonCastType::FromString(
config.at(JSON_CAST_TYPE).get<std::string>());
index_info.json_path = config.at(JSON_PATH).get<std::string>();
if (config.contains(JSON_CAST_FUNCTION)) {
index_info.json_cast_function =
config.at(JSON_CAST_FUNCTION).get<std::string>();
}
}
index_ = index::IndexFactory::GetInstance().CreateIndex(
index_info, file_manager_context);

View File

@ -186,3 +186,86 @@ TEST(JsonIndexTest, TestJsonContains) {
}
}
}
// Builds a DOUBLE JSON inverted index on "/a" with the STRING_TO_DOUBLE cast
// function and checks that an equality filter matches rows whose "/a" is a
// double, an integer, or a numeric string.
TEST(JsonIndexTest, TestJsonCast) {
    // Offsets 1-3 hold 1.0, 1 and "1.0" at "/a". Offset 0 has no "/a";
    // offsets 4-6 hold a bool, an array and an object.
    std::vector<std::string> json_raw_data = {
        R"(1)",
        R"({"a": 1.0})",
        R"({"a": 1})",
        R"({"a": "1.0"})",
        R"({"a": true})",
        R"({"a": [1, 2, 3]})",
        R"({"a": {"b": 1}})",
    };

    auto json_path = "/a";
    auto schema = std::make_shared<Schema>();
    auto json_fid = schema->AddDebugField("json", DataType::JSON);

    auto file_manager_ctx = storage::FileManagerContext();
    file_manager_ctx.fieldDataMeta.field_schema.set_data_type(
        milvus::proto::schema::JSON);
    file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get());
    file_manager_ctx.fieldDataMeta.field_id = json_fid.get();

    auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex(
        index::INVERTED_INDEX_TYPE,
        JsonCastType::FromString("DOUBLE"),
        json_path,
        file_manager_ctx,
        "STRING_TO_DOUBLE");
    auto json_index = std::unique_ptr<JsonInvertedIndex<double>>(
        static_cast<JsonInvertedIndex<double>*>(inv_index.release()));

    std::vector<milvus::Json> jsons;
    jsons.reserve(json_raw_data.size());
    for (auto& json : json_raw_data) {
        jsons.push_back(milvus::Json(simdjson::padded_string(json)));
    }

    auto json_field =
        std::make_shared<FieldData<milvus::Json>>(DataType::JSON, false);
    json_field->add_json_data(jsons);
    json_index->BuildWithFieldData({json_field});
    json_index->finish();
    json_index->create_reader();

    auto segment = segcore::CreateSealedSegment(schema);
    segcore::LoadIndexInfo load_index_info;
    load_index_info.field_id = json_fid.get();
    load_index_info.field_type = DataType::JSON;
    load_index_info.index = std::move(json_index);
    load_index_info.index_params = {{JSON_PATH, json_path}};
    segment->LoadIndex(load_index_info);

    auto cm = milvus::storage::RemoteChunkManagerSingleton::GetInstance()
                  .GetRemoteChunkManager();
    auto load_info = PrepareSingleFieldInsertBinlog(
        0, 0, 0, json_fid.get(), {json_field}, cm);
    segment->LoadFieldData(load_info);

    // Each case: (filter value, offsets of the rows expected to match).
    std::vector<std::tuple<proto::plan::GenericValue, std::vector<int64_t>>>
        test_cases;

    proto::plan::GenericValue value;
    value.set_int64_val(1);
    test_cases.emplace_back(value, std::vector<int64_t>{1, 2, 3});

    for (const auto& [filter_value, expected_offsets] : test_cases) {
        // Use the value of the current case, not the variable that happened
        // to build the first one, so additional cases are actually exercised.
        auto expr = std::make_shared<expr::UnaryRangeFilterExpr>(
            expr::ColumnInfo(json_fid, DataType::JSON, {"a"}, true),
            proto::plan::OpType::Equal,
            filter_value,
            std::vector<proto::plan::GenericValue>{});
        auto plan =
            std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, expr);

        auto result = query::ExecuteQueryExpr(
            plan, segment.get(), json_raw_data.size(), MAX_TIMESTAMP);

        EXPECT_EQ(result.count(), expected_offsets.size());
        for (auto& id : expected_offsets) {
            EXPECT_TRUE(result[id]);
        }
    }
}

View File

@ -195,6 +195,9 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
if jsonCastType, exist := indexParamsMap[common.JSONCastTypeKey]; exist {
indexParamsMap[common.JSONCastTypeKey] = strings.ToUpper(strings.TrimSpace(jsonCastType))
}
if jsonCastFunction, exist := indexParamsMap[common.JSONCastFunctionKey]; exist {
indexParamsMap[common.JSONCastFunctionKey] = strings.ToUpper(strings.TrimSpace(jsonCastFunction))
}
if err := ValidateAutoIndexMmapConfig(isVecIndex, indexParamsMap); err != nil {
return err

View File

@ -18,6 +18,8 @@ type INVERTEDChecker struct {
var validJSONCastTypes = []string{"BOOL", "DOUBLE", "VARCHAR"}
var validJSONCastFunctions = []string{"STRING_TO_DOUBLE"}
func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
// check json index params
isJSONIndex := typeutil.IsJSONType(dataType)
@ -30,6 +32,17 @@ func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[stri
if !lo.Contains(validJSONCastTypes, castType) {
return merr.WrapErrParameterInvalidMsg("json_cast_type %v is not supported", castType)
}
castFunction, exist := params[common.JSONCastFunctionKey]
if exist {
switch castFunction {
case "STRING_TO_DOUBLE":
if castType != "DOUBLE" {
return merr.WrapErrParameterInvalidMsg("json_cast_function %v is not supported for json_cast_type %v", castFunction, castType)
}
default:
return merr.WrapErrParameterInvalidMsg("json_cast_function %v is not supported", castFunction)
}
}
}
return c.scalarIndexChecker.CheckTrain(dataType, params)
}

View File

@ -151,8 +151,9 @@ const (
ConsistencyLevel = "consistency_level"
HintsKey = "hints"
JSONCastTypeKey = "json_cast_type"
JSONPathKey = "json_path"
JSONCastTypeKey = "json_cast_type"
JSONPathKey = "json_path"
JSONCastFunctionKey = "json_cast_function"
)
// Doc-in-doc-out