fix: fix searchPlan metricType modified concurrently (#30227)

issue: #30225
/kind bug
Signed-off-by: xige-16 <xi.ge@zilliz.com>

---------

Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
xige-16 2024-01-26 14:03:09 +08:00 committed by GitHub
parent 7ced0af197
commit e9fdd2475d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 101 additions and 87 deletions

View File

@ -35,10 +35,15 @@ class Collection {
}
// Returns a mutable reference to the index meta pointer held by this
// collection (the `index_meta_` member).
IndexMetaPtr&
GetIndexMeta() {
get_index_meta() {
return index_meta_;
}
// Replaces the index meta pointer held by this collection with `index_meta`.
// The argument is taken by value and copy-assigned into `index_meta_`.
void
set_index_meta(const IndexMetaPtr index_meta) {
index_meta_ = index_meta;
}
const std::string_view
get_collection_name() {
return collection_name_;

View File

@ -278,7 +278,6 @@ class SegmentGrowingImpl : public SegmentGrowing {
// Pre-search validation: asserts that a plan was supplied, then delegates to
// check_metric_type() with this segment's `index_meta_`.
void
check_search(const query::Plan* plan) const override {
Assert(plan);
// Validated against the index meta held by this segment.
check_metric_type(plan, index_meta_);
}
const ConcurrentVector<Timestamp>&

View File

@ -316,23 +316,4 @@ SegmentInternalInterface::LoadStringSkipIndex(
skip_index_.LoadString(field_id, chunk_id, var_column);
}
// Verifies that the metric type carried by a search plan agrees with the
// metric type recorded in the index meta of the field being searched.
//
// When the plan's metric type is empty it is filled in from the field's index
// meta, i.e. this function writes to the plan's search info even though
// `plan` is a pointer-to-const.
// NOTE(review): if one plan is checked by several segments at the same time,
// that write is unsynchronized here -- presumably the concurrent modification
// the enclosing change refers to; confirm against the callers.
//
// Throws SegcoreError(MetricTypeNotMatch) when the (possibly just filled-in)
// plan metric type differs from the one in the field's index meta.
void
SegmentInternalInterface::check_metric_type(
    const query::Plan* plan, const IndexMetaPtr index_meta) const {
    auto& search_info = plan->plan_node_->search_info_;
    auto field_index_meta =
        index_meta->GetFieldIndexMeta(FieldId(search_info.field_id_));

    // A plan without an explicit metric type inherits the indexed one.
    if (search_info.metric_type_.empty()) {
        search_info.metric_type_ = field_index_meta.GeMetricType();
    }

    if (search_info.metric_type_ == field_index_meta.GeMetricType()) {
        return;
    }

    throw SegcoreError(
        MetricTypeNotMatch,
        fmt::format("metric type not match, expected {}, actual {}.",
                    field_index_meta.GeMetricType(),
                    search_info.metric_type_));
}
} // namespace milvus::segcore

View File

@ -242,10 +242,6 @@ class SegmentInternalInterface : public SegmentInterface {
virtual std::pair<std::unique_ptr<IdArray>, std::vector<SegOffset>>
search_ids(const IdArray& id_array, Timestamp timestamp) const = 0;
void
check_metric_type(const query::Plan* plan,
const IndexMetaPtr index_meta) const;
/**
* Apply timestamp filtering on bitset, the query can't see an entity whose
* timestamp is bigger than the timestamp of query.

View File

@ -963,8 +963,6 @@ SegmentSealedImpl::check_search(const query::Plan* plan) const {
AssertInfo(plan->extra_info_opt_.has_value(),
"Extra info of search plan doesn't have value");
check_metric_type(plan, col_index_meta_);
if (!is_system_field_ready()) {
PanicInfo(
FieldNotLoaded,

View File

@ -25,6 +25,13 @@ CreateSearchPlanByExpr(CCollection c_col,
try {
auto res = milvus::query::CreateSearchPlanByExpr(
*col->get_schema(), serialized_expr_plan, size);
auto col_index_meta = col->get_index_meta();
auto field_id = milvus::query::GetFieldID(res.get());
AssertInfo(col_index_meta != nullptr, "index meta not exist");
auto field_index_meta =
col_index_meta->GetFieldIndexMeta(milvus::FieldId(field_id));
res->plan_node_->search_info_.metric_type_ =
field_index_meta.GeMetricType();
auto status = CStatus();
status.error_code = milvus::Success;

View File

@ -41,14 +41,14 @@ NewSegment(CCollection collection,
switch (seg_type) {
case Growing: {
auto seg = milvus::segcore::CreateGrowingSegment(
col->get_schema(), col->GetIndexMeta(), segment_id);
col->get_schema(), col->get_index_meta(), segment_id);
segment = std::move(seg);
break;
}
case Sealed:
case Indexing:
segment = milvus::segcore::CreateSealedSegment(
col->get_schema(), col->GetIndexMeta(), segment_id);
col->get_schema(), col->get_index_meta(), segment_id);
break;
default:
PanicInfo(milvus::UnexpectedError,

View File

@ -293,7 +293,9 @@ enum VectorType {
};
std::string
generate_collection_schema(std::string metric_type, int dim, VectorType vector_type) {
generate_collection_schema(std::string metric_type,
int dim,
VectorType vector_type) {
namespace schema = milvus::proto::schema;
schema::CollectionSchema collection_schema;
collection_schema.set_name("collection_test");
@ -425,8 +427,8 @@ TEST(CApiTest, SegmentTest) {
}
TEST(CApiTest, CPlan) {
std::string schema_string =
generate_collection_schema(knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
auto collection = NewCollection(schema_string.c_str());
// const char* dsl_string = R"(
@ -481,8 +483,8 @@ TEST(CApiTest, CPlan) {
}
TEST(CApiTest, CApiCPlan_float16) {
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, 16, VectorType::Float16Vector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, 16, VectorType::Float16Vector);
auto collection = NewCollection(schema_string.c_str());
milvus::proto::plan::PlanNode plan_node;
@ -521,8 +523,8 @@ TEST(CApiTest, CApiCPlan_float16) {
}
TEST(CApiTest, CApiCPlan_bfloat16) {
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, 16, VectorType::BFloat16Vector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, 16, VectorType::BFloat16Vector);
auto collection = NewCollection(schema_string.c_str());
milvus::proto::plan::PlanNode plan_node;
@ -2041,8 +2043,8 @@ TEST(CApiTest, Indexing_Without_Predicate) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::FloatVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::FloatVector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -2192,8 +2194,8 @@ TEST(CApiTest, Indexing_Expr_Without_Predicate) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::FloatVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::FloatVector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -2344,8 +2346,8 @@ TEST(CApiTest, Indexing_With_float_Predicate_Range) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::FloatVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::FloatVector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -2524,8 +2526,8 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Range) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::FloatVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::FloatVector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -2706,8 +2708,8 @@ TEST(CApiTest, Indexing_With_float_Predicate_Term) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::FloatVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::FloatVector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -2880,8 +2882,8 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Term) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::FloatVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::FloatVector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -3055,9 +3057,10 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Range) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
auto collection = NewCollection(schema_string.c_str());
std::string schema_string = generate_collection_schema(
knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
auto collection =
NewCollection(schema_string.c_str(), knowhere::metric::JACCARD);
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
auto status = NewSegment(collection, Growing, -1, &segment);
@ -3236,9 +3239,10 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Range) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
auto collection = NewCollection(schema_string.c_str());
std::string schema_string = generate_collection_schema(
knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
auto collection =
NewCollection(schema_string.c_str(), knowhere::metric::JACCARD);
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
auto status = NewSegment(collection, Growing, -1, &segment);
@ -3417,9 +3421,10 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
auto collection = NewCollection(schema_string.c_str());
std::string schema_string = generate_collection_schema(
knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
auto collection =
NewCollection(schema_string.c_str(), knowhere::metric::JACCARD);
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
auto status = NewSegment(collection, Growing, -1, &segment);
@ -3614,9 +3619,10 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Term) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
auto collection = NewCollection(schema_string.c_str());
std::string schema_string = generate_collection_schema(
knowhere::metric::JACCARD, DIM, VectorType::BinaryVector);
auto collection =
NewCollection(schema_string.c_str(), knowhere::metric::JACCARD);
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
auto status = NewSegment(collection, Growing, -1, &segment);
@ -3828,8 +3834,8 @@ TEST(CApiTest, SealedSegmentTest) {
TEST(CApiTest, SealedSegment_search_float_Predicate_Range) {
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::FloatVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::FloatVector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -3982,8 +3988,8 @@ TEST(CApiTest, SealedSegment_search_float_Predicate_Range) {
TEST(CApiTest, SealedSegment_search_without_predicates) {
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::FloatVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::FloatVector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -4062,8 +4068,8 @@ TEST(CApiTest, SealedSegment_search_without_predicates) {
TEST(CApiTest, SealedSegment_search_float_With_Expr_Predicate_Range) {
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::FloatVector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::FloatVector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -4510,7 +4516,8 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_WHEN_IP) {
}
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP) {
auto c_collection = NewCollection(get_default_schema_config());
auto c_collection =
NewCollection(get_default_schema_config(), knowhere::metric::IP);
CSegmentInterface segment;
auto status = NewSegment(c_collection, Growing, -1, &segment);
ASSERT_EQ(status.error_code, Success);
@ -4853,8 +4860,8 @@ TEST(CApiTest, Indexing_Without_Predicate_float16) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::Float16Vector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::Float16Vector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -4881,7 +4888,8 @@ TEST(CApiTest, Indexing_Without_Predicate_float16) {
milvus::proto::plan::PlanNode plan_node;
auto vector_anns = plan_node.mutable_vector_anns();
vector_anns->set_vector_type(milvus::proto::plan::VectorType::Float16Vector);
vector_anns->set_vector_type(
milvus::proto::plan::VectorType::Float16Vector);
vector_anns->set_placeholder_tag("$0");
vector_anns->set_field_id(100);
auto query_info = vector_anns->mutable_query_info();
@ -4969,7 +4977,7 @@ TEST(CApiTest, Indexing_Without_Predicate_float16) {
c_load_index_info,
knowhere::Version::GetCurrentVersion().VersionNumber());
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
// load index for vec field, load raw data for scalar field
auto sealed_segment = SealedCreator(schema, dataset);
sealed_segment->DropFieldData(FieldId(100));
@ -5004,8 +5012,8 @@ TEST(CApiTest, Indexing_Without_Predicate_bfloat16) {
// insert data to segment
constexpr auto TOPK = 5;
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, DIM, VectorType::BFloat16Vector);
std::string schema_string = generate_collection_schema(
knowhere::metric::L2, DIM, VectorType::BFloat16Vector);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
CSegmentInterface segment;
@ -5032,7 +5040,8 @@ TEST(CApiTest, Indexing_Without_Predicate_bfloat16) {
milvus::proto::plan::PlanNode plan_node;
auto vector_anns = plan_node.mutable_vector_anns();
vector_anns->set_vector_type(milvus::proto::plan::VectorType::BFloat16Vector);
vector_anns->set_vector_type(
milvus::proto::plan::VectorType::BFloat16Vector);
vector_anns->set_placeholder_tag("$0");
vector_anns->set_field_id(100);
auto query_info = vector_anns->mutable_query_info();
@ -5120,7 +5129,7 @@ TEST(CApiTest, Indexing_Without_Predicate_bfloat16) {
c_load_index_info,
knowhere::Version::GetCurrentVersion().VersionNumber());
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
// load index for vec field, load raw data for scalar field
auto sealed_segment = SealedCreator(schema, dataset);
sealed_segment->DropFieldData(FieldId(100));
@ -5152,7 +5161,8 @@ TEST(CApiTest, Indexing_Without_Predicate_bfloat16) {
}
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP_FLOAT16) {
auto c_collection = NewCollection(get_float16_schema_config());
auto c_collection =
NewCollection(get_float16_schema_config(), knowhere::metric::IP);
CSegmentInterface segment;
auto status = NewSegment(c_collection, Growing, -1, &segment);
ASSERT_EQ(status.error_code, Success);
@ -5204,7 +5214,7 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP_FLOAT16) {
placeholderGroups.push_back(placeholderGroup);
CSearchResult search_result;
auto res =
auto res =
Search(segment, plan, placeholderGroup, {}, ts_offset, &search_result);
ASSERT_EQ(res.error_code, Success);
@ -5216,7 +5226,8 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP_FLOAT16) {
}
TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP_BFLOAT16) {
auto c_collection = NewCollection(get_bfloat16_schema_config());
auto c_collection =
NewCollection(get_bfloat16_schema_config(), knowhere::metric::IP);
CSegmentInterface segment;
auto status = NewSegment(c_collection, Growing, -1, &segment);
ASSERT_EQ(status.error_code, Success);
@ -5268,7 +5279,7 @@ TEST(CApiTest, RANGE_SEARCH_WITH_RADIUS_AND_RANGE_FILTER_WHEN_IP_BFLOAT16) {
placeholderGroups.push_back(placeholderGroup);
CSearchResult search_result;
auto res =
auto res =
Search(segment, plan, placeholderGroup, {}, ts_offset, &search_result);
ASSERT_EQ(res.error_code, Success);

View File

@ -1110,9 +1110,22 @@ GenRandomIds(int rows, int64_t seed = 42) {
}
// Builds a milvus::segcore::Collection from the given schema blob, attaches a
// CollectionIndexMeta to it, and hands the collection back as an opaque
// CCollection handle. Ownership is released from the unique_ptr, so the
// caller becomes responsible for the object.
//
// `metric_type` (default knowhere::metric::L2) is written as the
// "metric_type" index param of every field in the schema.
inline CCollection
NewCollection(const char* schema_proto_blob) {
NewCollection(const char* schema_proto_blob,
const MetricType metric_type = knowhere::metric::L2) {
auto proto = std::string(schema_proto_blob);
auto collection = std::make_unique<milvus::segcore::Collection>(proto);
auto schema = collection->get_schema();
milvus::proto::segcore::CollectionIndexMeta col_index_meta;
// One index-meta entry per schema field; no filtering by field type is done
// here, so every field receives the same metric type.
for (auto field : schema->get_fields()) {
auto field_index_meta = col_index_meta.add_index_metas();
auto index_param = field_index_meta->add_index_params();
index_param->set_key("metric_type");
index_param->set_value(metric_type);
// `field.first` is the key of the schema's field entry; its raw value is
// stored as the field id of this index-meta entry.
field_index_meta->set_fieldid(field.first.get());
}
collection->set_index_meta(
std::make_shared<CollectionIndexMeta>(col_index_meta));
// Give up ownership: the raw pointer is returned as the C handle.
return (void*)collection.release();
}

View File

@ -42,7 +42,7 @@ type SearchPlan struct {
cSearchPlan C.CSearchPlan
}
func createSearchPlanByExpr(ctx context.Context, col *Collection, expr []byte, metricType string) (*SearchPlan, error) {
func createSearchPlanByExpr(ctx context.Context, col *Collection, expr []byte) (*SearchPlan, error) {
if col.collectionPtr == nil {
return nil, errors.New("nil collection ptr, collectionID = " + fmt.Sprintln(col.id))
}
@ -88,11 +88,9 @@ type SearchRequest struct {
}
func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb.SearchRequest, placeholderGrp []byte) (*SearchRequest, error) {
var err error
var plan *SearchPlan
metricType := req.GetReq().GetMetricType()
expr := req.Req.SerializedExprPlan
plan, err = createSearchPlanByExpr(ctx, collection, expr, metricType)
plan, err := createSearchPlanByExpr(ctx, collection, expr)
if err != nil {
return nil, err
}
@ -112,6 +110,12 @@ func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb.
return nil, err
}
metricTypeInPlan := plan.getMetricType()
if len(metricType) != 0 && metricType != metricTypeInPlan {
plan.delete()
return nil, merr.WrapErrParameterInvalid(metricTypeInPlan, metricType, "metric type not match")
}
var fieldID C.int64_t
status = C.GetFieldID(plan.cSearchPlan, &fieldID)
if err = HandleCStatus(ctx, &status, "get fieldID from plan failed"); err != nil {

View File

@ -58,7 +58,7 @@ func (suite *PlanSuite) TestPlanCreateByExpr() {
expr, err := proto.Marshal(planNode)
suite.NoError(err)
_, err = createSearchPlanByExpr(context.Background(), suite.collection, expr, "")
_, err = createSearchPlanByExpr(context.Background(), suite.collection, expr)
suite.Error(err)
}
@ -67,7 +67,7 @@ func (suite *PlanSuite) TestPlanFail() {
id: -1,
}
_, err := createSearchPlanByExpr(context.Background(), collection, nil, "")
_, err := createSearchPlanByExpr(context.Background(), collection, nil)
suite.Error(err)
}

View File

@ -166,7 +166,7 @@ func (suite *ReduceSuite) TestReduceAllFunc() {
proto.UnmarshalText(planStr, &planpb)
serializedPlan, err := proto.Marshal(&planpb)
suite.NoError(err)
plan, err := createSearchPlanByExpr(context.Background(), suite.collection, serializedPlan, "")
plan, err := createSearchPlanByExpr(context.Background(), suite.collection, serializedPlan)
suite.NoError(err)
searchReq, err := parseSearchRequest(context.Background(), plan, placeGroupByte)
searchReq.mvccTimestamp = typeutil.MaxTimestamp