From ae8a9cebb85cab531adfdae5191cb7e01b20cb4e Mon Sep 17 00:00:00 2001 From: Bingyi Sun Date: Fri, 14 Mar 2025 21:36:10 +0800 Subject: [PATCH] fix: fix json index not-equal filter (#40648) issue: #35528 pr: https://github.com/milvus-io/milvus/pull/40647 --------- Signed-off-by: sunby --- internal/core/src/common/JsonCastType.h | 39 ++++++++++ internal/core/src/index/IndexFactory.cpp | 15 ++-- internal/core/src/index/IndexFactory.h | 3 +- internal/core/src/index/IndexInfo.h | 3 +- internal/core/src/index/JsonInvertedIndex.cpp | 10 +-- .../src/indexbuilder/ScalarIndexCreator.cpp | 5 +- internal/core/src/segcore/load_index_c.cpp | 5 +- internal/core/unittest/test_expr.cpp | 75 +++++++++++++++++-- internal/datacoord/index_service_test.go | 29 ++++--- internal/proxy/task_index.go | 4 + internal/proxy/task_index_test.go | 4 +- .../util/indexparamcheck/inverted_checker.go | 11 +-- .../indexparamcheck/inverted_checker_test.go | 5 +- tests/python_client/testcases/test_index.py | 2 +- 14 files changed, 154 insertions(+), 56 deletions(-) create mode 100644 internal/core/src/common/JsonCastType.h diff --git a/internal/core/src/common/JsonCastType.h b/internal/core/src/common/JsonCastType.h new file mode 100644 index 0000000000..cb49d027f7 --- /dev/null +++ b/internal/core/src/common/JsonCastType.h @@ -0,0 +1,39 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include "common/EasyAssert.h" +namespace milvus { + +enum class JsonCastType { BOOL, DOUBLE, VARCHAR }; + +inline auto +format_as(JsonCastType f) { + return fmt::underlying(f); +} + +inline const std::unordered_map JsonCastTypeMap = { + {"BOOL", JsonCastType::BOOL}, + {"DOUBLE", JsonCastType::DOUBLE}, + {"VARCHAR", JsonCastType::VARCHAR}}; + +inline JsonCastType +ConvertToJsonCastType(const std::string& str) { + auto it = JsonCastTypeMap.find(str); + if (it != JsonCastTypeMap.end()) { + return it->second; + } + PanicInfo(Unsupported, "Invalid json cast type: " + str); +} + +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index 2d7871d683..3448c99104 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -19,6 +19,7 @@ #include #include "common/EasyAssert.h" #include "common/FieldDataInterface.h" +#include "common/JsonCastType.h" #include "common/Types.h" #include "index/VectorMemIndex.h" #include "index/Utils.h" @@ -374,29 +375,23 @@ IndexFactory::CreateComplexScalarIndex( IndexBasePtr IndexFactory::CreateJsonIndex( IndexType index_type, - DataType cast_dtype, + JsonCastType cast_dtype, const std::string& nested_path, const storage::FileManagerContext& file_manager_context) { AssertInfo(index_type == INVERTED_INDEX_TYPE, "Invalid index type for json index"); switch (cast_dtype) { - case DataType::BOOL: + case JsonCastType::BOOL: return std::make_unique>( proto::schema::DataType::Bool, nested_path, file_manager_context); - case DataType::INT8: - case DataType::INT16: - case DataType::INT32: - case DataType::INT64: - case DataType::FLOAT: - case DataType::DOUBLE: + case JsonCastType::DOUBLE: return std::make_unique>( proto::schema::DataType::Double, nested_path, file_manager_context); - case DataType::STRING: - case DataType::VARCHAR: + case JsonCastType::VARCHAR: return std::make_unique>( proto::schema::DataType::String, nested_path, diff --git a/internal/core/src/index/IndexFactory.h b/internal/core/src/index/IndexFactory.h index 786c06f84b..f8417b95b1 100644 --- a/internal/core/src/index/IndexFactory.h +++ b/internal/core/src/index/IndexFactory.h @@ -21,6 +21,7 @@ #include #include +#include "common/JsonCastType.h" #include "common/Types.h" #include "common/type_c.h" #include "index/Index.h" @@ -106,7 +107,7 @@ class IndexFactory { IndexBasePtr CreateJsonIndex(IndexType index_type, - DataType cast_dtype, + JsonCastType cast_dtype, const std::string& nested_path, const storage::FileManagerContext& file_manager_context = storage::FileManagerContext()); diff --git a/internal/core/src/index/IndexInfo.h b/internal/core/src/index/IndexInfo.h index 3c730baeeb..a86258f8b5 100644 --- a/internal/core/src/index/IndexInfo.h +++ b/internal/core/src/index/IndexInfo.h @@ -15,6 +15,7 @@ // limitations under the License. #pragma once +#include "common/JsonCastType.h" #include "common/Types.h" namespace milvus::index { @@ -27,7 +28,7 @@ struct CreateIndexInfo { std::string field_name; int64_t dim; int32_t scalar_index_engine_version; - DataType json_cast_type; + JsonCastType json_cast_type; std::string json_path; }; diff --git a/internal/core/src/index/JsonInvertedIndex.cpp b/internal/core/src/index/JsonInvertedIndex.cpp index 24cb1a4e95..b36bfa0d75 100644 --- a/internal/core/src/index/JsonInvertedIndex.cpp +++ b/internal/core/src/index/JsonInvertedIndex.cpp @@ -35,7 +35,7 @@ JsonInvertedIndex::build_index_for_json( for (int64_t i = 0; i < n; i++) { auto json_column = static_cast(data->RawValue(i)); if (this->schema_.nullable() && !data->is_valid(i)) { - this->null_offset_.push_back(i); + this->null_offset_.push_back(offset); this->wrapper_->template add_multi_data( nullptr, 0, offset++); continue; @@ -51,12 +51,10 @@ JsonInvertedIndex::build_index_for_json( err, *json_column, nested_path_); - if (err == simdjson::INVALID_JSON_POINTER) { - LOG_WARN("Invalid json pointer, json: {}, pointer: {}", - *json_column, - nested_path_); + if (err == simdjson::NO_SUCH_FIELD || + err == simdjson::INVALID_JSON_POINTER) { + this->null_offset_.push_back(offset); } - this->null_offset_.push_back(i); this->wrapper_->template add_multi_data( nullptr, 0, offset++); continue; diff --git a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp index 0872871eea..c06a16a8a4 100644 --- a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp +++ b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp @@ -12,6 +12,7 @@ #include "indexbuilder/ScalarIndexCreator.h" #include "common/Consts.h" #include "common/FieldDataInterface.h" +#include "common/JsonCastType.h" #include "common/Types.h" #include "index/IndexFactory.h" #include "index/IndexInfo.h" @@ -43,8 +44,8 @@ ScalarIndexCreator::ScalarIndexCreator( index_info.field_type = dtype_; index_info.index_type = index_type(); if (dtype == DataType::JSON) { - index_info.json_cast_type = static_cast( - std::stoi(config.at(JSON_CAST_TYPE).get())); + index_info.json_cast_type = + ConvertToJsonCastType(config.at(JSON_CAST_TYPE).get()); index_info.json_path = config.at(JSON_PATH).get(); } index_ = index::IndexFactory::GetInstance().CreateIndex( diff --git a/internal/core/src/segcore/load_index_c.cpp b/internal/core/src/segcore/load_index_c.cpp index 185bc97012..b847c4d297 100644 --- a/internal/core/src/segcore/load_index_c.cpp +++ b/internal/core/src/segcore/load_index_c.cpp @@ -14,6 +14,7 @@ #include "common/Consts.h" #include "common/FieldMeta.h" #include "common/EasyAssert.h" +#include "common/JsonCastType.h" #include "common/Types.h" #include "common/type_c.h" #include "index/Index.h" @@ -307,8 +308,8 @@ AppendIndexV2(CTraceContext c_trace, CLoadIndexInfo c_load_index_info) { config[milvus::index::INDEX_FILES] = load_index_info->index_files; if (load_index_info->field_type == milvus::DataType::JSON) { - index_info.json_cast_type = static_cast( - std::stoi(config.at(JSON_CAST_TYPE).get())); + index_info.json_cast_type = milvus::ConvertToJsonCastType( + config.at(JSON_CAST_TYPE).get()); index_info.json_path = config.at(JSON_PATH).get(); } milvus::storage::FileManagerContext fileManagerContext( diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index c4a9d9ac6d..a41bc3d61c 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -16395,21 +16395,21 @@ class JsonIndexTestFixture : public testing::Test { json_path = "/bool"; lower_bound.set_bool_val(std::numeric_limits::min()); upper_bound.set_bool_val(std::numeric_limits::max()); - cast_type = milvus::DataType::BOOL; + cast_type = JsonCastType::BOOL; wrong_type_val.set_int64_val(123); } else if constexpr (std::is_same_v) { schema_data_type = proto::schema::Int64; json_path = "/int"; lower_bound.set_int64_val(std::numeric_limits::min()); upper_bound.set_int64_val(std::numeric_limits::max()); - cast_type = milvus::DataType::INT64; + cast_type = JsonCastType::DOUBLE; wrong_type_val.set_string_val("123"); } else if constexpr (std::is_same_v) { schema_data_type = proto::schema::Double; json_path = "/double"; lower_bound.set_float_val(std::numeric_limits::min()); upper_bound.set_float_val(std::numeric_limits::max()); - cast_type = milvus::DataType::DOUBLE; + cast_type = JsonCastType::DOUBLE; wrong_type_val.set_string_val("123"); } else if constexpr (std::is_same_v) { schema_data_type = proto::schema::String; @@ -16417,7 +16417,7 @@ class JsonIndexTestFixture : public testing::Test { lower_bound.set_string_val(""); std::string s(1024, '9'); upper_bound.set_string_val(s); - cast_type = milvus::DataType::STRING; + cast_type = JsonCastType::VARCHAR; wrong_type_val.set_int64_val(123); } } @@ -16425,7 +16425,7 @@ class JsonIndexTestFixture : public testing::Test { std::string json_path; proto::plan::GenericValue lower_bound; proto::plan::GenericValue upper_bound; - milvus::DataType cast_type; + JsonCastType cast_type; proto::plan::GenericValue wrong_type_val; }; @@ -16560,3 +16560,68 @@ TYPED_TEST(JsonIndexTestFixture, TestJsonIndexUnaryExpr) { final = ExecuteQueryExpr(plan, seg.get(), N, MAX_TIMESTAMP); EXPECT_EQ(final.count(), expect_count); } + +TEST(JsonIndexTest, TestJsonNotEqualExpr) { + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto i64_fid = schema->AddDebugField("age64", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateSealedSegment(schema); + segcore::LoadIndexInfo load_index_info; + + auto file_manager_ctx = storage::FileManagerContext(); + file_manager_ctx.fieldDataMeta.field_schema.set_data_type( + milvus::proto::schema::JSON); + file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get()); + auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex( + index::INVERTED_INDEX_TYPE, + JsonCastType::DOUBLE, + "/a", + file_manager_ctx); + + using json_index_type = index::JsonInvertedIndex; + auto json_index = std::unique_ptr( + static_cast(inv_index.release())); + auto json_strs = std::vector{ + R"({"a": 1.0})", R"({"a": "abc"})", R"({"a": 3.0})", R"({"a": null})"}; + auto json_field = + std::make_shared>(DataType::JSON, false); + auto json_field2 = + std::make_shared>(DataType::JSON, false); + std::vector jsons; + + for (auto& json : json_strs) { + jsons.push_back(milvus::Json(simdjson::padded_string(json))); + } + json_field->add_json_data(jsons); + json_field2->add_json_data(jsons); + + json_index->BuildWithFieldData({json_field, json_field2}); + json_index->finish(); + json_index->create_reader(); + + load_index_info.field_id = json_fid.get(); + load_index_info.field_type = DataType::JSON; + load_index_info.index = std::move(json_index); + load_index_info.index_params = {{JSON_PATH, "/a"}}; + seg->LoadIndex(load_index_info); + + auto json_field_data_info = FieldDataInfo( + json_fid.get(), 2 * json_strs.size(), {json_field, json_field2}); + seg->LoadFieldData(json_fid, json_field_data_info); + + proto::plan::GenericValue val; + val.set_int64_val(1); + auto unary_expr = std::make_shared( + expr::ColumnInfo(json_fid, DataType::JSON, {"a"}), + proto::plan::OpType::NotEqual, + val); + auto plan = + std::make_shared(DEFAULT_PLANNODE_ID, unary_expr); + auto final = + ExecuteQueryExpr(plan, seg.get(), 2 * json_strs.size(), MAX_TIMESTAMP); + EXPECT_EQ(final.count(), 2 * json_strs.size() - 2); +} diff --git a/internal/datacoord/index_service_test.go b/internal/datacoord/index_service_test.go index 733019c9c5..706dcf020d 100644 --- a/internal/datacoord/index_service_test.go +++ b/internal/datacoord/index_service_test.go @@ -19,7 +19,6 @@ package datacoord import ( "context" "fmt" - "strconv" "testing" "time" @@ -212,7 +211,7 @@ func TestServer_CreateIndex(t *testing.T) { }, { Key: common.JSONCastTypeKey, - Value: "int64", + Value: "double", }, { Key: common.IndexTypeKey, @@ -2688,7 +2687,7 @@ func TestJsonIndex(t *testing.T) { req := &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "a", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_String))}, {Key: common.JSONPathKey, Value: "json[\"a\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "varchar"}, {Key: common.JSONPathKey, Value: "json[\"a\"]"}}, } resp, err := s.CreateIndex(context.Background(), req) assert.NoError(t, merr.CheckRPCCall(resp, err)) @@ -2696,7 +2695,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_String))}, {Key: common.JSONPathKey, Value: "json[\"c\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "varchar"}, {Key: common.JSONPathKey, Value: "json[\"c\"]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.NoError(t, merr.CheckRPCCall(resp, err)) @@ -2705,7 +2704,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 1, IndexName: "", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_String))}, {Key: common.JSONPathKey, Value: "json2[\"c\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "varchar"}, {Key: common.JSONPathKey, Value: "json2[\"c\"]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.NoError(t, merr.CheckRPCCall(resp, err)) @@ -2714,7 +2713,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "a", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_String))}, {Key: common.JSONPathKey, Value: "json[\"a\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "varchar"}, {Key: common.JSONPathKey, Value: "json[\"a\"]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.NoError(t, merr.CheckRPCCall(resp, err)) @@ -2723,7 +2722,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "a", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_Int16))}, {Key: common.JSONPathKey, Value: "json[\"a\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "double"}, {Key: common.JSONPathKey, Value: "json[\"a\"]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.Error(t, merr.CheckRPCCall(resp, err)) @@ -2732,7 +2731,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "b", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_Int16))}, {Key: common.JSONPathKey, Value: "json[\"a\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "double"}, {Key: common.JSONPathKey, Value: "json[\"a\"]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.Error(t, merr.CheckRPCCall(resp, err)) @@ -2741,7 +2740,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "a", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_Int16))}, {Key: common.JSONPathKey, Value: "json[\"b\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "double"}, {Key: common.JSONPathKey, Value: "json[\"b\"]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.Error(t, merr.CheckRPCCall(resp, err)) @@ -2759,7 +2758,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 1, IndexName: "c", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_Int16))}, {Key: common.JSONPathKey, Value: "bad_json[\"a\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "double"}, {Key: common.JSONPathKey, Value: "bad_json[\"a\"]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.Error(t, merr.CheckRPCCall(resp, err)) @@ -2768,7 +2767,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 2, IndexName: "", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_Int16))}, {Key: common.JSONPathKey, Value: "dynamic_a_field"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "double"}, {Key: common.JSONPathKey, Value: "dynamic_a_field"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.NoError(t, merr.CheckRPCCall(resp, err)) @@ -2777,7 +2776,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "d", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_Int16))}, {Key: common.JSONPathKey, Value: "json[a][\"b\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "double"}, {Key: common.JSONPathKey, Value: "json[a][\"b\"]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.Error(t, merr.CheckRPCCall(resp, err)) @@ -2786,7 +2785,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "e", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_Int16))}, {Key: common.JSONPathKey, Value: "json[\"a\"][\"b"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "double"}, {Key: common.JSONPathKey, Value: "json[\"a\"][\"b"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.Error(t, merr.CheckRPCCall(resp, err)) @@ -2795,7 +2794,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "f", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_Int16))}, {Key: common.JSONPathKey, Value: "json[\"a\"[\"b]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "double"}, {Key: common.JSONPathKey, Value: "json[\"a\"[\"b]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.Error(t, merr.CheckRPCCall(resp, err)) @@ -2804,7 +2803,7 @@ func TestJsonIndex(t *testing.T) { req = &indexpb.CreateIndexRequest{ FieldID: 0, IndexName: "g", - IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: strconv.Itoa(int(schemapb.DataType_Int16))}, {Key: common.JSONPathKey, Value: "json[\"a\"][0][\"b\"]"}}, + IndexParams: []*commonpb.KeyValuePair{{Key: common.JSONCastTypeKey, Value: "double"}, {Key: common.JSONPathKey, Value: "json[\"a\"][0][\"b\"]"}}, } resp, err = s.CreateIndex(context.Background(), req) assert.NoError(t, merr.CheckRPCCall(resp, err)) diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 4d6c3ea53e..4edd16f541 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -193,6 +193,10 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error { } } + if jsonCastType, exist := indexParamsMap[common.JSONCastTypeKey]; exist { + indexParamsMap[common.JSONCastTypeKey] = strings.ToUpper(strings.TrimSpace(jsonCastType)) + } + if err := ValidateAutoIndexMmapConfig(isVecIndex, indexParamsMap); err != nil { return err } diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 7c1f12570c..7dabcdf4a5 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -1141,7 +1141,7 @@ func Test_parseIndexParams(t *testing.T) { ExtraParams: []*commonpb.KeyValuePair{ { Key: common.JSONCastTypeKey, - Value: "1", + Value: "double", }, { Key: common.IndexTypeKey, @@ -1170,7 +1170,7 @@ func Test_parseIndexParams(t *testing.T) { ExtraParams: []*commonpb.KeyValuePair{ { Key: common.JSONCastTypeKey, - Value: "1", + Value: "double", }, { Key: common.IndexTypeKey, diff --git a/internal/util/indexparamcheck/inverted_checker.go b/internal/util/indexparamcheck/inverted_checker.go index 61914bf9c0..a83cb9bb1d 100644 --- a/internal/util/indexparamcheck/inverted_checker.go +++ b/internal/util/indexparamcheck/inverted_checker.go @@ -2,7 +2,6 @@ package indexparamcheck import ( "fmt" - "strconv" "github.com/samber/lo" @@ -17,7 +16,7 @@ type INVERTEDChecker struct { scalarIndexChecker } -var validJSONCastTypes = []int{int(schemapb.DataType_Bool), int(schemapb.DataType_Int8), int(schemapb.DataType_Int16), int(schemapb.DataType_Int32), int(schemapb.DataType_Int64), int(schemapb.DataType_Float), int(schemapb.DataType_Double), int(schemapb.DataType_String), int(schemapb.DataType_VarChar)} +var validJSONCastTypes = []string{"BOOL", "DOUBLE", "VARCHAR"} func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { // check json index params @@ -28,12 +27,8 @@ func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[stri return merr.WrapErrParameterMissing(common.JSONCastTypeKey, "json index must specify cast type") } - castTypeInt, err := strconv.Atoi(castType) - if err != nil { - return merr.WrapErrParameterInvalid(common.JSONCastTypeKey, "json_cast_type must be DataType") - } - if !lo.Contains(validJSONCastTypes, castTypeInt) { - return merr.WrapErrParameterInvalid(common.JSONCastTypeKey, "json_cast_type is not supported") + if !lo.Contains(validJSONCastTypes, castType) { + return merr.WrapErrParameterInvalidMsg("json_cast_type %v is not supported", castType) } } return c.scalarIndexChecker.CheckTrain(dataType, params) diff --git a/internal/util/indexparamcheck/inverted_checker_test.go b/internal/util/indexparamcheck/inverted_checker_test.go index 8f16cba83d..a5726a3a1d 100644 --- a/internal/util/indexparamcheck/inverted_checker_test.go +++ b/internal/util/indexparamcheck/inverted_checker_test.go @@ -1,7 +1,6 @@ package indexparamcheck import ( - "strconv" "testing" "github.com/stretchr/testify/assert" @@ -27,7 +26,7 @@ func Test_INVERTEDIndexChecker(t *testing.T) { func Test_CheckTrain(t *testing.T) { c := newINVERTEDChecker() - assert.NoError(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{"json_cast_type": strconv.Itoa(int(schemapb.DataType_Bool)), "json_path": "json['a']"})) - assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{"json_cast_type": strconv.Itoa(int(schemapb.DataType_Array)), "json_path": "json['a']"})) + assert.NoError(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{"json_cast_type": "BOOL", "json_path": "json['a']"})) + assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{"json_cast_type": "array", "json_path": "json['a']"})) assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{"json_cast_type": "abc", "json_path": "json['a']"})) } diff --git a/tests/python_client/testcases/test_index.py b/tests/python_client/testcases/test_index.py index 930a7d8765..a1857b2bd7 100644 --- a/tests/python_client/testcases/test_index.py +++ b/tests/python_client/testcases/test_index.py @@ -1234,7 +1234,7 @@ class TestIndexInvalid(TestcaseBase): expected: success """ collection_w = self.init_collection_general(prefix, is_index=False, vector_data_type=vector_data_type)[0] - scalar_index_params = {"index_type": "INVERTED", "json_cast_type": DataType.INT32, "json_path": ct.default_json_field_name+"['a']"} + scalar_index_params = {"index_type": "INVERTED", "json_cast_type": "double", "json_path": ct.default_json_field_name+"['a']"} collection_w.create_index(ct.default_json_field_name, index_params=scalar_index_params) @pytest.mark.tags(CaseLabel.L1)