From dc89730a5007d9eea5d3aa3e485329fcb89359ad Mon Sep 17 00:00:00 2001 From: yah01 Date: Thu, 2 Nov 2023 23:52:16 +0800 Subject: [PATCH] Support collection-level mmap control (#26901) Signed-off-by: yah01 --- internal/core/src/common/LoadInfo.h | 1 + .../core/src/segcore/SegmentSealedImpl.cpp | 2 +- internal/core/src/segcore/Types.h | 1 + .../core/src/segcore/load_field_data_c.cpp | 8 + internal/core/src/segcore/load_field_data_c.h | 5 + internal/core/src/segcore/load_index_c.cpp | 6 +- internal/core/src/segcore/load_index_c.h | 1 + internal/core/unittest/test_c_api.cpp | 36 +++-- internal/core/unittest/test_sealed.cpp | 1 + .../unittest/test_utils/storage_test_utils.h | 1 + internal/querycoordv2/job/job_test.go | 10 +- internal/querycoordv2/job/utils.go | 3 +- .../querycoordv2/meta/collection_manager.go | 2 +- .../meta/collection_manager_test.go | 20 +-- .../querycoordv2/meta/coordinator_broker.go | 7 +- .../meta/coordinator_broker_test.go | 13 +- internal/querycoordv2/meta/mock_broker.go | 114 +++++++------- .../querycoordv2/observers/leader_observer.go | 8 +- .../observers/leader_observer_test.go | 11 +- internal/querycoordv2/server_test.go | 2 +- internal/querycoordv2/services_test.go | 2 +- internal/querycoordv2/task/executor.go | 37 ++++- internal/querycoordv2/task/task_test.go | 73 +++++---- internal/querycoordv2/task/utils.go | 13 ++ internal/querycoordv2/task/utils_test.go | 140 ++++++++++++++++-- internal/querynodev2/segments/collection.go | 2 + .../segments/load_field_data_info.go | 7 + .../querynodev2/segments/load_index_info.go | 9 +- internal/querynodev2/segments/reduce_test.go | 2 +- .../querynodev2/segments/retrieve_test.go | 2 +- internal/querynodev2/segments/search_test.go | 2 +- internal/querynodev2/segments/segment.go | 6 +- .../querynodev2/segments/segment_loader.go | 42 ++++-- .../segments/segment_loader_test.go | 16 +- internal/querynodev2/segments/segment_test.go | 2 +- internal/querynodev2/services.go | 8 +- pkg/common/common.go | 30 +++- 37 files changed, 451 insertions(+), 194 deletions(-) diff --git a/internal/core/src/common/LoadInfo.h b/internal/core/src/common/LoadInfo.h index 44273ae336..968b590b98 100644 --- a/internal/core/src/common/LoadInfo.h +++ b/internal/core/src/common/LoadInfo.h @@ -29,6 +29,7 @@ struct FieldBinlogInfo { int64_t field_id; int64_t row_count = -1; std::vector entries_nums; + bool enable_mmap{false}; std::vector insert_files; }; diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 5bfc44421b..f69b46ea75 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -205,7 +205,7 @@ SegmentSealedImpl::LoadFieldData(const LoadFieldDataInfo& load_info) { "to thread pool, " << "segmentID:" << this->id_ << ", fieldID:" << info.field_id; - if (load_info.mmap_dir_path.empty() || + if (!info.enable_mmap || SystemProperty::Instance().IsSystem(field_id)) { LoadFieldData(field_id, field_data_info); } else { diff --git a/internal/core/src/segcore/Types.h b/internal/core/src/segcore/Types.h index 187671e55f..260d9848af 100644 --- a/internal/core/src/segcore/Types.h +++ b/internal/core/src/segcore/Types.h @@ -35,6 +35,7 @@ struct LoadIndexInfo { int64_t segment_id; int64_t field_id; DataType field_type; + bool enable_mmap; std::string mmap_dir_path; int64_t index_id; int64_t index_build_id; diff --git a/internal/core/src/segcore/load_field_data_c.cpp b/internal/core/src/segcore/load_field_data_c.cpp index a44e940588..485d5823cd 100644 --- a/internal/core/src/segcore/load_field_data_c.cpp +++ b/internal/core/src/segcore/load_field_data_c.cpp @@ -88,3 +88,11 @@ AppendMMapDirPath(CLoadFieldDataInfo c_load_field_data_info, static_cast(c_load_field_data_info); load_field_data_info->mmap_dir_path = std::string(c_dir_path); } + +void +EnableMmap(CLoadFieldDataInfo c_load_field_data_info, + int64_t field_id, + bool enabled) { + auto info = static_cast(c_load_field_data_info); + info->field_infos[field_id].enable_mmap = enabled; +} \ No newline at end of file diff --git a/internal/core/src/segcore/load_field_data_c.h b/internal/core/src/segcore/load_field_data_c.h index 938eae405c..ccc67d9a98 100644 --- a/internal/core/src/segcore/load_field_data_c.h +++ b/internal/core/src/segcore/load_field_data_c.h @@ -46,6 +46,11 @@ void AppendMMapDirPath(CLoadFieldDataInfo c_load_field_data_info, const char* dir_path); +void +EnableMmap(CLoadFieldDataInfo c_load_field_data_info, + int64_t field_id, + bool enabled); + #ifdef __cplusplus } #endif diff --git a/internal/core/src/segcore/load_index_c.cpp b/internal/core/src/segcore/load_index_c.cpp index 416382be3c..a27294bd8d 100644 --- a/internal/core/src/segcore/load_index_c.cpp +++ b/internal/core/src/segcore/load_index_c.cpp @@ -79,6 +79,7 @@ AppendFieldInfo(CLoadIndexInfo c_load_index_info, int64_t segment_id, int64_t field_id, enum CDataType field_type, + bool enable_mmap, const char* mmap_dir_path) { try { auto load_index_info = @@ -88,6 +89,7 @@ AppendFieldInfo(CLoadIndexInfo c_load_index_info, load_index_info->segment_id = segment_id; load_index_info->field_id = field_id; load_index_info->field_type = milvus::DataType(field_type); + load_index_info->enable_mmap = enable_mmap; load_index_info->mmap_dir_path = std::string(mmap_dir_path); auto status = CStatus(); @@ -253,8 +255,10 @@ AppendIndexV2(CLoadIndexInfo c_load_index_info) { milvus::index::IndexFactory::GetInstance().CreateIndex( index_info, fileManagerContext); - if (!load_index_info->mmap_dir_path.empty() && + if (load_index_info->enable_mmap && load_index_info->index->IsMmapSupported()) { + AssertInfo(!load_index_info->mmap_dir_path.empty(), + "mmap directory path is empty"); auto filepath = std::filesystem::path(load_index_info->mmap_dir_path) / std::to_string(load_index_info->segment_id) / diff --git a/internal/core/src/segcore/load_index_c.h b/internal/core/src/segcore/load_index_c.h index 2ef8f46797..9a4dd5dc6f 100644 --- a/internal/core/src/segcore/load_index_c.h +++ b/internal/core/src/segcore/load_index_c.h @@ -41,6 +41,7 @@ AppendFieldInfo(CLoadIndexInfo c_load_index_info, int64_t segment_id, int64_t field_id, enum CDataType field_type, + bool enable_mmap, const char* mmap_dir_path); CStatus diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 91db85ef9a..c6084f5a33 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -273,6 +274,15 @@ TEST(CApiTest, CollectionTest) { DeleteCollection(collection); } +TEST(CApiTest, LoadInfoTest) { + auto load_info = std::make_shared(); + auto c_load_info = reinterpret_cast(load_info.get()); + AppendLoadFieldInfo(c_load_info, 100, 100); + EnableMmap(c_load_info, 100, true); + + EXPECT_TRUE(load_info->field_infos.at(100).enable_mmap); +} + TEST(CApiTest, SetIndexMetaTest) { auto collection = NewCollection(get_default_schema_config()); @@ -1637,7 +1647,7 @@ TEST(CApiTest, LoadIndexInfo) { ASSERT_EQ(status.error_code, Success); std::string field_name = "field0"; status = AppendFieldInfo( - c_load_index_info, 0, 0, 0, 0, CDataType::FloatVector, ""); + c_load_index_info, 0, 0, 0, 0, CDataType::FloatVector, false, ""); ASSERT_EQ(status.error_code, Success); AppendIndexEngineVersionToLoadInfo( c_load_index_info, @@ -1799,7 +1809,7 @@ TEST(CApiTest, Indexing_Without_Predicate) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -1941,7 +1951,7 @@ TEST(CApiTest, Indexing_Expr_Without_Predicate) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -2113,7 +2123,7 @@ TEST(CApiTest, Indexing_With_float_Predicate_Range) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -2287,7 +2297,7 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Range) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -2453,7 +2463,7 @@ TEST(CApiTest, Indexing_With_float_Predicate_Term) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -2620,7 +2630,7 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Term) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -2793,7 +2803,7 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Range) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -2966,7 +2976,7 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Range) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -3133,7 +3143,7 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -3323,7 +3333,7 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Term) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -3496,7 +3506,7 @@ TEST(CApiTest, SealedSegment_search_float_Predicate_Range) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -3723,7 +3733,7 @@ TEST(CApiTest, SealedSegment_search_float_With_Expr_Predicate_Range) { AppendIndexParam( c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( - c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, false, ""); AppendIndexEngineVersionToLoadInfo( c_load_index_info, knowhere::Version::GetCurrentVersion().VersionNumber()); diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index 8c16643e1c..26a9c98bea 100644 --- a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -1239,6 +1239,7 @@ TEST(Sealed, GetVectorFromChunkCache) { FieldBinlogInfo{fakevec_id.get(), N, std::vector{N}, + false, std::vector{file_name}}; segment_sealed->AddFieldDataInfoForSealed(LoadFieldDataInfo{ std::map{ diff --git a/internal/core/unittest/test_utils/storage_test_utils.h b/internal/core/unittest/test_utils/storage_test_utils.h index 02a69c0b98..fd712ced5e 100644 --- a/internal/core/unittest/test_utils/storage_test_utils.h +++ b/internal/core/unittest/test_utils/storage_test_utils.h @@ -70,6 +70,7 @@ PrepareInsertBinlog(int64_t collection_id, FieldBinlogInfo{field_id, static_cast(row_count), std::vector{int64_t(row_count)}, + false, std::vector{file}}); }; diff --git a/internal/querycoordv2/job/job_test.go b/internal/querycoordv2/job/job_test.go index 8f230cb861..bd5c3dcd8e 100644 --- a/internal/querycoordv2/job/job_test.go +++ b/internal/querycoordv2/job/job_test.go @@ -126,7 +126,7 @@ func (suite *JobSuite) SetupSuite() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(vChannels, segmentBinlogs, nil).Maybe() } - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything). + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything). Return(nil, nil) suite.broker.EXPECT().DescribeIndex(mock.Anything, mock.Anything). Return(nil, nil) @@ -1237,10 +1237,10 @@ func (suite *JobSuite) TestCallLoadPartitionFailed() { // call LoadPartitions failed at get schema getSchemaErr := fmt.Errorf("mock get schema error") suite.broker.ExpectedCalls = lo.Filter(suite.broker.ExpectedCalls, func(call *mock.Call, _ int) bool { - return call.Method != "GetCollectionSchema" + return call.Method != "DescribeCollection" }) for _, collection := range suite.collections { - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, collection).Return(nil, getSchemaErr) + suite.broker.EXPECT().DescribeCollection(mock.Anything, collection).Return(nil, getSchemaErr) loadCollectionReq := &querypb.LoadCollectionRequest{ CollectionID: collection, } @@ -1280,9 +1280,9 @@ func (suite *JobSuite) TestCallLoadPartitionFailed() { } suite.broker.ExpectedCalls = lo.Filter(suite.broker.ExpectedCalls, func(call *mock.Call, _ int) bool { - return call.Method != "DescribeIndex" && call.Method != "GetCollectionSchema" + return call.Method != "DescribeIndex" && call.Method != "DescribeCollection" }) - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, nil) suite.broker.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(nil, nil) } diff --git a/internal/querycoordv2/job/utils.go b/internal/querycoordv2/job/utils.go index c6a9b26cfc..6369dbb46a 100644 --- a/internal/querycoordv2/job/utils.go +++ b/internal/querycoordv2/job/utils.go @@ -73,10 +73,11 @@ func loadPartitions(ctx context.Context, var err error var schema *schemapb.CollectionSchema if withSchema { - schema, err = broker.GetCollectionSchema(ctx, collection) + collectionInfo, err := broker.DescribeCollection(ctx, collection) if err != nil { return err } + schema = collectionInfo.GetSchema() } indexes, err := broker.DescribeIndex(ctx, collection) if err != nil { diff --git a/internal/querycoordv2/meta/collection_manager.go b/internal/querycoordv2/meta/collection_manager.go index 43f4dcaf3e..1f1de75f9f 100644 --- a/internal/querycoordv2/meta/collection_manager.go +++ b/internal/querycoordv2/meta/collection_manager.go @@ -130,7 +130,7 @@ func (m *CollectionManager) Recover(broker Broker) error { for _, collection := range collections { // Dropped collection should be deprecated - _, err = broker.GetCollectionSchema(ctx, collection.GetCollectionID()) + _, err = broker.DescribeCollection(ctx, collection.GetCollectionID()) if errors.Is(err, merr.ErrCollectionNotFound) { ctxLog.Info("skip dropped collection during recovery", zap.Int64("collection", collection.GetCollectionID())) m.catalog.ReleaseCollection(collection.GetCollectionID()) diff --git a/internal/querycoordv2/meta/collection_manager_test.go b/internal/querycoordv2/meta/collection_manager_test.go index 311bbbe6af..2adc7b758d 100644 --- a/internal/querycoordv2/meta/collection_manager_test.go +++ b/internal/querycoordv2/meta/collection_manager_test.go @@ -178,7 +178,7 @@ func (suite *CollectionManagerSuite) TestGet() { func (suite *CollectionManagerSuite) TestUpdate() { mgr := suite.mgr - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, nil) for _, collection := range suite.collections { if len(suite.partitions[collection]) > 0 { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil) @@ -251,7 +251,7 @@ func (suite *CollectionManagerSuite) TestGetFieldIndex() { func (suite *CollectionManagerSuite) TestRemove() { mgr := suite.mgr - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, nil) for _, collection := range suite.collections { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() } @@ -322,7 +322,7 @@ func (suite *CollectionManagerSuite) TestRecover_normal() { // recover successfully for _, collection := range suite.collections { - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, collection).Return(nil, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, collection).Return(nil, nil) if len(suite.partitions[collection]) > 0 { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil) } @@ -342,7 +342,7 @@ func (suite *CollectionManagerSuite) TestRecover_normal() { func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() { mgr := suite.mgr suite.releaseAll() - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, nil) // test put collection with partitions for i, collection := range suite.collections { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() @@ -432,9 +432,9 @@ func (suite *CollectionManagerSuite) TestRecover_with_dropped() { for _, collection := range suite.collections { if collection == droppedCollection { - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, collection).Return(nil, merr.ErrCollectionNotFound) + suite.broker.EXPECT().DescribeCollection(mock.Anything, collection).Return(nil, merr.ErrCollectionNotFound) } else { - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, collection).Return(nil, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, collection).Return(nil, nil) } if len(suite.partitions[collection]) != 0 { if collection == droppedCollection { @@ -465,8 +465,8 @@ func (suite *CollectionManagerSuite) TestRecover_with_dropped() { } func (suite *CollectionManagerSuite) TestRecover_Failed() { - mockErr1 := fmt.Errorf("mock GetCollectionSchema err") - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, mockErr1) + mockErr1 := fmt.Errorf("mock.DescribeCollection err") + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, mockErr1) suite.clearMemory() err := suite.mgr.Recover(suite.broker) suite.Error(err) @@ -474,7 +474,7 @@ func (suite *CollectionManagerSuite) TestRecover_Failed() { mockErr2 := fmt.Errorf("mock GetPartitions err") suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0] - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, nil) suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return(nil, mockErr2) suite.clearMemory() err = suite.mgr.Recover(suite.broker) @@ -539,7 +539,7 @@ func (suite *CollectionManagerSuite) TestUpgradeRecover() { suite.releaseAll() mgr := suite.mgr - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, nil) for _, collection := range suite.collections { if len(suite.partitions[collection]) > 0 { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil) diff --git a/internal/querycoordv2/meta/coordinator_broker.go b/internal/querycoordv2/meta/coordinator_broker.go index d85b2a0224..9d1709199b 100644 --- a/internal/querycoordv2/meta/coordinator_broker.go +++ b/internal/querycoordv2/meta/coordinator_broker.go @@ -25,7 +25,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" @@ -40,7 +39,7 @@ import ( ) type Broker interface { - GetCollectionSchema(ctx context.Context, collectionID UniqueID) (*schemapb.CollectionSchema, error) + DescribeCollection(ctx context.Context, collectionID UniqueID) (*milvuspb.DescribeCollectionResponse, error) GetPartitions(ctx context.Context, collectionID UniqueID) ([]UniqueID, error) GetRecoveryInfo(ctx context.Context, collectionID UniqueID, partitionID UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentBinlogs, error) DescribeIndex(ctx context.Context, collectionID UniqueID) ([]*indexpb.IndexInfo, error) @@ -64,7 +63,7 @@ func NewCoordinatorBroker( } } -func (broker *CoordinatorBroker) GetCollectionSchema(ctx context.Context, collectionID UniqueID) (*schemapb.CollectionSchema, error) { +func (broker *CoordinatorBroker) DescribeCollection(ctx context.Context, collectionID UniqueID) (*milvuspb.DescribeCollectionResponse, error) { ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) defer cancel() @@ -80,7 +79,7 @@ func (broker *CoordinatorBroker) GetCollectionSchema(ctx context.Context, collec log.Ctx(ctx).Warn("failed to get collection schema", zap.Error(err)) return nil, err } - return resp.GetSchema(), nil + return resp, nil } func (broker *CoordinatorBroker) GetPartitions(ctx context.Context, collectionID UniqueID) ([]UniqueID, error) { diff --git a/internal/querycoordv2/meta/coordinator_broker_test.go b/internal/querycoordv2/meta/coordinator_broker_test.go index a5e7f8abd5..98330e7f2b 100644 --- a/internal/querycoordv2/meta/coordinator_broker_test.go +++ b/internal/querycoordv2/meta/coordinator_broker_test.go @@ -66,13 +66,14 @@ func (s *CoordinatorBrokerRootCoordSuite) TestGetCollectionSchema() { s.Run("normal case", func() { s.rootcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - Schema: &schemapb.CollectionSchema{Name: "test_schema"}, + Status: merr.Success(), + Schema: &schemapb.CollectionSchema{Name: "test_schema"}, + CollectionName: "test_schema", }, nil) - schema, err := s.broker.GetCollectionSchema(ctx, collectionID) + resp, err := s.broker.DescribeCollection(ctx, collectionID) s.NoError(err) - s.Equal("test_schema", schema.GetName()) + s.Equal("test_schema", resp.GetCollectionName()) s.resetMock() }) @@ -80,7 +81,7 @@ func (s *CoordinatorBrokerRootCoordSuite) TestGetCollectionSchema() { s.rootcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything). Return(nil, errors.New("mock error")) - _, err := s.broker.GetCollectionSchema(ctx, collectionID) + _, err := s.broker.DescribeCollection(ctx, collectionID) s.Error(err) s.resetMock() }) @@ -91,7 +92,7 @@ func (s *CoordinatorBrokerRootCoordSuite) TestGetCollectionSchema() { Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}, }, nil) - _, err := s.broker.GetCollectionSchema(ctx, collectionID) + _, err := s.broker.DescribeCollection(ctx, collectionID) s.Error(err) s.resetMock() }) diff --git a/internal/querycoordv2/meta/mock_broker.go b/internal/querycoordv2/meta/mock_broker.go index a05bb3692d..9ba70eb8c4 100644 --- a/internal/querycoordv2/meta/mock_broker.go +++ b/internal/querycoordv2/meta/mock_broker.go @@ -8,11 +8,11 @@ import ( datapb "github.com/milvus-io/milvus/internal/proto/datapb" indexpb "github.com/milvus-io/milvus/internal/proto/indexpb" + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + mock "github.com/stretchr/testify/mock" querypb "github.com/milvus-io/milvus/internal/proto/querypb" - - schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) // MockBroker is an autogenerated mock type for the Broker type @@ -28,6 +28,61 @@ func (_m *MockBroker) EXPECT() *MockBroker_Expecter { return &MockBroker_Expecter{mock: &_m.Mock} } +// DescribeCollection provides a mock function with given fields: ctx, collectionID +func (_m *MockBroker) DescribeCollection(ctx context.Context, collectionID int64) (*milvuspb.DescribeCollectionResponse, error) { + ret := _m.Called(ctx, collectionID) + + var r0 *milvuspb.DescribeCollectionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) (*milvuspb.DescribeCollectionResponse, error)); ok { + return rf(ctx, collectionID) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) *milvuspb.DescribeCollectionResponse); ok { + r0 = rf(ctx, collectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DescribeCollectionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, collectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_DescribeCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeCollection' +type MockBroker_DescribeCollection_Call struct { + *mock.Call +} + +// DescribeCollection is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +func (_e *MockBroker_Expecter) DescribeCollection(ctx interface{}, collectionID interface{}) *MockBroker_DescribeCollection_Call { + return &MockBroker_DescribeCollection_Call{Call: _e.mock.On("DescribeCollection", ctx, collectionID)} +} + +func (_c *MockBroker_DescribeCollection_Call) Run(run func(ctx context.Context, collectionID int64)) *MockBroker_DescribeCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockBroker_DescribeCollection_Call) Return(_a0 *milvuspb.DescribeCollectionResponse, _a1 error) *MockBroker_DescribeCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_DescribeCollection_Call) RunAndReturn(run func(context.Context, int64) (*milvuspb.DescribeCollectionResponse, error)) *MockBroker_DescribeCollection_Call { + _c.Call.Return(run) + return _c +} + // DescribeIndex provides a mock function with given fields: ctx, collectionID func (_m *MockBroker) DescribeIndex(ctx context.Context, collectionID int64) ([]*indexpb.IndexInfo, error) { ret := _m.Called(ctx, collectionID) @@ -83,61 +138,6 @@ func (_c *MockBroker_DescribeIndex_Call) RunAndReturn(run func(context.Context, return _c } -// GetCollectionSchema provides a mock function with given fields: ctx, collectionID -func (_m *MockBroker) GetCollectionSchema(ctx context.Context, collectionID int64) (*schemapb.CollectionSchema, error) { - ret := _m.Called(ctx, collectionID) - - var r0 *schemapb.CollectionSchema - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int64) (*schemapb.CollectionSchema, error)); ok { - return rf(ctx, collectionID) - } - if rf, ok := ret.Get(0).(func(context.Context, int64) *schemapb.CollectionSchema); ok { - r0 = rf(ctx, collectionID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*schemapb.CollectionSchema) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { - r1 = rf(ctx, collectionID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockBroker_GetCollectionSchema_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionSchema' -type MockBroker_GetCollectionSchema_Call struct { - *mock.Call -} - -// GetCollectionSchema is a helper method to define mock.On call -// - ctx context.Context -// - collectionID int64 -func (_e *MockBroker_Expecter) GetCollectionSchema(ctx interface{}, collectionID interface{}) *MockBroker_GetCollectionSchema_Call { - return &MockBroker_GetCollectionSchema_Call{Call: _e.mock.On("GetCollectionSchema", ctx, collectionID)} -} - -func (_c *MockBroker_GetCollectionSchema_Call) Run(run func(ctx context.Context, collectionID int64)) *MockBroker_GetCollectionSchema_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int64)) - }) - return _c -} - -func (_c *MockBroker_GetCollectionSchema_Call) Return(_a0 *schemapb.CollectionSchema, _a1 error) *MockBroker_GetCollectionSchema_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockBroker_GetCollectionSchema_Call) RunAndReturn(run func(context.Context, int64) (*schemapb.CollectionSchema, error)) *MockBroker_GetCollectionSchema_Call { - _c.Call.Return(run) - return _c -} - // GetIndexInfo provides a mock function with given fields: ctx, collectionID, segmentID func (_m *MockBroker) GetIndexInfo(ctx context.Context, collectionID int64, segmentID int64) ([]*querypb.FieldIndexInfo, error) { ret := _m.Called(ctx, collectionID, segmentID) diff --git a/internal/querycoordv2/observers/leader_observer.go b/internal/querycoordv2/observers/leader_observer.go index 5778a29766..40e87d0972 100644 --- a/internal/querycoordv2/observers/leader_observer.go +++ b/internal/querycoordv2/observers/leader_observer.go @@ -277,14 +277,14 @@ func (o *LeaderObserver) sync(ctx context.Context, replicaID int64, leaderView * zap.String("channel", leaderView.Channel), ) - schema, err := o.broker.GetCollectionSchema(ctx, leaderView.CollectionID) + collectionInfo, err := o.broker.DescribeCollection(ctx, leaderView.CollectionID) if err != nil { - log.Warn("sync distribution failed, cannot get schema of collection", zap.Error(err)) + log.Warn("failed to get collection info", zap.Error(err)) return false } partitions, err := utils.GetPartitions(o.meta.CollectionManager, leaderView.CollectionID) if err != nil { - log.Warn("sync distribution failed, cannot get partitions of collection", zap.Error(err)) + log.Warn("failed to get partitions", zap.Error(err)) return false } @@ -296,7 +296,7 @@ func (o *LeaderObserver) sync(ctx context.Context, replicaID int64, leaderView * ReplicaID: replicaID, Channel: leaderView.Channel, Actions: diffs, - Schema: schema, + Schema: collectionInfo.GetSchema(), LoadMeta: &querypb.LoadMetaInfo{ LoadType: o.meta.GetLoadType(leaderView.CollectionID), CollectionID: leaderView.CollectionID, diff --git a/internal/querycoordv2/observers/leader_observer_test.go b/internal/querycoordv2/observers/leader_observer_test.go index 3a2738f4ff..44f0e2a28a 100644 --- a/internal/querycoordv2/observers/leader_observer_test.go +++ b/internal/querycoordv2/observers/leader_observer_test.go @@ -27,6 +27,7 @@ import ( "go.uber.org/atomic" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" @@ -116,7 +117,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() { Infos: []*datapb.SegmentInfo{info}, } schema := utils.CreateTestSchema() - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, int64(1)).Return(&milvuspb.DescribeCollectionResponse{Schema: schema}, nil) suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return( &datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( @@ -197,7 +198,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() { }, } schema := utils.CreateTestSchema() - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, int64(1)).Return(&milvuspb.DescribeCollectionResponse{Schema: schema}, nil) info := &datapb.SegmentInfo{ ID: 1, CollectionID: 1, @@ -342,7 +343,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() { &datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, int64(1)).Return(&milvuspb.DescribeCollectionResponse{Schema: schema}, nil) observer.target.UpdateCollectionNextTarget(int64(1)) observer.target.UpdateCollectionCurrentTarget(1) observer.dist.SegmentDistManager.Update(1, utils.CreateTestSegment(1, 1, 1, 1, 1, "test-insert-channel")) @@ -410,7 +411,7 @@ func (suite *LeaderObserverTestSuite) TestSyncRemovedSegments() { observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) schema := utils.CreateTestSchema() - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, int64(1)).Return(&milvuspb.DescribeCollectionResponse{Schema: schema}, nil) channels := []*datapb.VchannelInfo{ { @@ -490,7 +491,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncRemovedSegments() { }, } schema := utils.CreateTestSchema() - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil) + suite.broker.EXPECT().DescribeCollection(mock.Anything, int64(1)).Return(&milvuspb.DescribeCollectionResponse{Schema: schema}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return( channels, segments, nil) observer.target.UpdateCollectionNextTarget(int64(1)) diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index 5a0ddecda7..2f3ed26eaf 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -523,7 +523,7 @@ func (suite *ServerSuite) hackServer() { suite.server.checkerController, ) - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{}, nil).Maybe() + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Schema: &schemapb.CollectionSchema{}}, nil).Maybe() suite.broker.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(nil, nil).Maybe() for _, collection := range suite.collections { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index 80be1701d8..6dba44acad 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -1735,7 +1735,7 @@ func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) { } func (suite *ServiceSuite) expectLoadPartitions() { - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything). + suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything). Return(nil, nil) suite.broker.EXPECT().DescribeIndex(mock.Anything, mock.Anything). Return(nil, nil) diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index b5e2b508a3..461b0ad4d5 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -242,9 +242,9 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { } }() - schema, err := ex.broker.GetCollectionSchema(ctx, task.CollectionID()) + collectionInfo, err := ex.broker.DescribeCollection(ctx, task.CollectionID()) if err != nil { - log.Warn("failed to get schema of collection", zap.Error(err)) + log.Warn("failed to get collection info", zap.Error(err)) return err } partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID()) @@ -277,7 +277,13 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { loadInfo := utils.PackSegmentLoadInfo(resp, indexes) // Get shard leader for the given replica and segment - leader, ok := getShardLeader(ex.meta.ReplicaManager, ex.dist, task.CollectionID(), action.Node(), segment.GetInsertChannel()) + leader, ok := getShardLeader( + ex.meta.ReplicaManager, + ex.dist, + task.CollectionID(), + action.Node(), + segment.GetInsertChannel(), + ) if !ok { msg := "no shard leader for the segment to execute loading" err = merr.WrapErrChannelNotFound(segment.GetInsertChannel(), "shard delegator not found") @@ -293,7 +299,15 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { return err } - req := packLoadSegmentRequest(task, action, schema, loadMeta, loadInfo, indexInfo) + req := packLoadSegmentRequest( + task, + action, + collectionInfo.GetSchema(), + collectionInfo.GetProperties(), + loadMeta, + loadInfo, + indexInfo, + ) loadTask := NewLoadSegmentsTask(task, step, req) ex.merger.Add(loadTask) log.Info("load segment task committed") @@ -396,9 +410,9 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error { ctx := task.Context() - schema, err := ex.broker.GetCollectionSchema(ctx, task.CollectionID()) + collectionInfo, err := ex.broker.DescribeCollection(ctx, task.CollectionID()) if err != nil { - log.Warn("failed to get schema of collection") + log.Warn("failed to get collection info") return err } partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID()) @@ -411,7 +425,7 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error { log.Warn("fail to get index meta of collection") return err } - metricType, err := getMetricType(indexInfo, schema) + metricType, err := getMetricType(indexInfo, collectionInfo.GetSchema()) if err != nil { log.Warn("failed to get metric type", zap.Error(err)) return err @@ -429,7 +443,14 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error { log.Warn(msg, zap.String("channelName", action.ChannelName())) return merr.WrapErrChannelReduplicate(action.ChannelName()) } - req := packSubChannelRequest(task, action, schema, loadMeta, dmChannel, indexInfo) + req := packSubChannelRequest( + task, + action, + collectionInfo.GetSchema(), + loadMeta, + dmChannel, + indexInfo, + ) err = fillSubChannelRequest(ctx, req, ex.broker) if err != nil { log.Warn("failed to subscribe channel, failed to fill the request with segments", diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 601e586e24..f2aef18d26 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -28,6 +28,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" @@ -196,11 +197,13 @@ func (suite *TaskSuite) TestSubscribeChannelTask() { partitions := []int64{100, 101} // Expect - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection). - Return(&schemapb.CollectionSchema{ - Name: "TestSubscribeChannelTask", - Fields: []*schemapb.FieldSchema{ - {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection). + Return(&milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Name: "TestSubscribeChannelTask", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + }, }, }, nil) for channel, segment := range suite.growingSegments { @@ -384,10 +387,12 @@ func (suite *TaskSuite) TestLoadSegmentTask() { } // Expect - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ - Name: "TestLoadSegmentTask", - Fields: []*schemapb.FieldSchema{ - {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).Return(&milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Name: "TestLoadSegmentTask", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + }, }, }, nil) suite.broker.EXPECT().DescribeIndex(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{ @@ -480,10 +485,12 @@ func (suite *TaskSuite) TestLoadSegmentTaskNotIndex() { } // Expect - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ - Name: "TestLoadSegmentTask", - Fields: []*schemapb.FieldSchema{ - {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).Return(&milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Name: "TestLoadSegmentTask", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + }, }, }, nil) suite.broker.EXPECT().DescribeIndex(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{ @@ -576,10 +583,12 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() { } // Expect - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ - Name: "TestLoadSegmentTask", - Fields: []*schemapb.FieldSchema{ - {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).Return(&milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Name: "TestLoadSegmentTask", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + }, }, }, nil) for _, segment := range suite.loadSegments { @@ -774,10 +783,12 @@ func (suite *TaskSuite) TestMoveSegmentTask() { } // Expect - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ - Name: "TestMoveSegmentTask", - Fields: []*schemapb.FieldSchema{ - {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).Return(&milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Name: "TestMoveSegmentTask", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + }, }, }, nil) suite.broker.EXPECT().DescribeIndex(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{ @@ -944,10 +955,12 @@ func (suite *TaskSuite) TestTaskCanceled() { } // Expect - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ - Name: "TestSubscribeChannelTask", - Fields: []*schemapb.FieldSchema{ - {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).Return(&milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Name: "TestSubscribeChannelTask", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + }, }, }, nil) suite.broker.EXPECT().DescribeIndex(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{ @@ -1031,10 +1044,12 @@ func (suite *TaskSuite) TestSegmentTaskStale() { } // Expect - suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ - Name: "TestSegmentTaskStale", - Fields: []*schemapb.FieldSchema{ - {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + suite.broker.EXPECT().DescribeCollection(mock.Anything, suite.collection).Return(&milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Name: "TestSegmentTaskStale", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "vec", DataType: schemapb.DataType_FloatVector}, + }, }, }, nil) suite.broker.EXPECT().DescribeIndex(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{ diff --git a/internal/querycoordv2/task/utils.go b/internal/querycoordv2/task/utils.go index f9ee116745..e895f6af5d 100644 --- a/internal/querycoordv2/task/utils.go +++ b/internal/querycoordv2/task/utils.go @@ -102,6 +102,7 @@ func packLoadSegmentRequest( task *SegmentTask, action Action, schema *schemapb.CollectionSchema, + collectionProperties []*commonpb.KeyValuePair, loadMeta *querypb.LoadMetaInfo, loadInfo *querypb.SegmentLoadInfo, indexInfo []*indexpb.IndexInfo, @@ -110,6 +111,18 @@ func packLoadSegmentRequest( if action.Type() == ActionTypeUpdate { loadScope = querypb.LoadScope_Index } + + // field mmap enabled if collection-level mmap enabled or the field mmap enabled + collectionMmapEnabled := common.IsMmapEnabled(collectionProperties...) + for _, field := range schema.GetFields() { + if collectionMmapEnabled { + field.TypeParams = append(field.TypeParams, &commonpb.KeyValuePair{ + Key: common.MmapEnabledKey, + Value: "true", + }) + } + } + return &querypb.LoadSegmentsRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments), diff --git a/internal/querycoordv2/task/utils_test.go b/internal/querycoordv2/task/utils_test.go index 788cadc471..bd685344a0 100644 --- a/internal/querycoordv2/task/utils_test.go +++ b/internal/querycoordv2/task/utils_test.go @@ -17,17 +17,25 @@ package task import ( + "context" "testing" + "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/common" ) -func Test_getMetricType(t *testing.T) { +type UtilsSuite struct { + suite.Suite +} + +func (s *UtilsSuite) TestGetMetricType() { collection := int64(1) schema := &schemapb.CollectionSchema{ Name: "TestGetMetricType", @@ -51,29 +59,139 @@ func Test_getMetricType(t *testing.T) { FieldID: 100, } - t.Run("test normal", func(t *testing.T) { + s.Run("test normal", func() { metricType, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, schema) - assert.NoError(t, err) - assert.Equal(t, "L2", metricType) + s.NoError(err) + s.Equal("L2", metricType) }) - t.Run("test get vec field failed", func(t *testing.T) { + s.Run("test get vec field failed", func() { _, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{ Name: "TestGetMetricType", }) - assert.Error(t, err) + s.Error(err) }) - t.Run("test field id mismatch", func(t *testing.T) { + s.Run("test field id mismatch", func() { _, err := getMetricType([]*indexpb.IndexInfo{indexInfo}, &schemapb.CollectionSchema{ Name: "TestGetMetricType", Fields: []*schemapb.FieldSchema{ {FieldID: -1, Name: "vec", DataType: schemapb.DataType_FloatVector}, }, }) - assert.Error(t, err) + s.Error(err) }) - t.Run("test no metric type", func(t *testing.T) { + s.Run("test no metric type", func() { _, err := getMetricType([]*indexpb.IndexInfo{indexInfo2}, schema) - assert.Error(t, err) + s.Error(err) }) } + +func (s *UtilsSuite) TestPackLoadSegmentRequest() { + ctx := context.Background() + + action := NewSegmentAction(1, ActionTypeGrow, "test-ch", 100) + task, err := NewSegmentTask( + ctx, + time.Second, + nil, + 1, + 10, + action, + ) + s.NoError(err) + + collectionInfoResp := &milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + }, + }, + }, + Properties: []*commonpb.KeyValuePair{ + { + Key: common.MmapEnabledKey, + Value: "false", + }, + }, + } + + req := packLoadSegmentRequest( + task, + action, + collectionInfoResp.GetSchema(), + collectionInfoResp.GetProperties(), + &querypb.LoadMetaInfo{ + LoadType: querypb.LoadType_LoadCollection, + }, + &querypb.SegmentLoadInfo{}, + nil, + ) + + s.True(req.GetNeedTransfer()) + s.Equal(task.CollectionID(), req.CollectionID) + s.Equal(task.ReplicaID(), req.ReplicaID) + s.Equal(action.Node(), req.GetDstNodeID()) + for _, field := range req.GetSchema().GetFields() { + s.False(common.IsMmapEnabled(field.GetTypeParams()...)) + } +} + +func (s *UtilsSuite) TestPackLoadSegmentRequestMmap() { + ctx := context.Background() + + action := NewSegmentAction(1, ActionTypeGrow, "test-ch", 100) + task, err := NewSegmentTask( + ctx, + time.Second, + nil, + 1, + 10, + action, + ) + s.NoError(err) + + collectionInfoResp := &milvuspb.DescribeCollectionResponse{ + Schema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + }, + }, + }, + Properties: []*commonpb.KeyValuePair{ + { + Key: common.MmapEnabledKey, + Value: "true", + }, + }, + } + + req := packLoadSegmentRequest( + task, + action, + collectionInfoResp.GetSchema(), + collectionInfoResp.GetProperties(), + &querypb.LoadMetaInfo{ + LoadType: querypb.LoadType_LoadCollection, + }, + &querypb.SegmentLoadInfo{}, + nil, + ) + + s.True(req.GetNeedTransfer()) + s.Equal(task.CollectionID(), req.CollectionID) + s.Equal(task.ReplicaID(), req.ReplicaID) + s.Equal(action.Node(), req.GetDstNodeID()) + for _, field := range req.GetSchema().GetFields() { + s.True(common.IsMmapEnabled(field.GetTypeParams()...)) + } +} + +func TestUtils(t *testing.T) { + suite.Run(t, new(UtilsSuite)) +} diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index b8240e1336..9b61e1dd6b 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -72,6 +72,8 @@ func (m *collectionManager) PutOrRef(collectionID int64, schema *schemapb.Collec defer m.mut.Unlock() if collection, ok := m.collections[collectionID]; ok { + // the schema may be changed even the collection is loaded + collection.schema = schema collection.Ref(1) return } diff --git a/internal/querynodev2/segments/load_field_data_info.go b/internal/querynodev2/segments/load_field_data_info.go index 5904457853..44b349d44e 100644 --- a/internal/querynodev2/segments/load_field_data_info.go +++ b/internal/querynodev2/segments/load_field_data_info.go @@ -64,6 +64,13 @@ func (ld *LoadFieldDataInfo) appendLoadFieldDataPath(fieldID int64, binlog *data return HandleCStatus(&status, "appendLoadFieldDataPath failed") } +func (ld *LoadFieldDataInfo) enableMmap(fieldID int64, enabled bool) { + cFieldID := C.int64_t(fieldID) + cEnabled := C.bool(enabled) + + C.EnableMmap(ld.cLoadFieldDataInfo, cFieldID, cEnabled) +} + func (ld *LoadFieldDataInfo) appendMMapDirPath(dir string) { cDir := C.CString(dir) defer C.free(unsafe.Pointer(cDir)) diff --git a/internal/querynodev2/segments/load_index_info.go b/internal/querynodev2/segments/load_index_info.go index 3cb008d1c6..81f4b833ca 100644 --- a/internal/querynodev2/segments/load_index_info.go +++ b/internal/querynodev2/segments/load_index_info.go @@ -56,12 +56,12 @@ func deleteLoadIndexInfo(info *LoadIndexInfo) { C.DeleteLoadIndexInfo(info.cLoadIndexInfo) } -func (li *LoadIndexInfo) appendLoadIndexInfo(indexInfo *querypb.FieldIndexInfo, collectionID int64, partitionID int64, segmentID int64, fieldType schemapb.DataType) error { +func (li *LoadIndexInfo) appendLoadIndexInfo(indexInfo *querypb.FieldIndexInfo, collectionID int64, partitionID int64, segmentID int64, fieldType schemapb.DataType, enableMmap bool) error { fieldID := indexInfo.FieldID indexPaths := indexInfo.IndexFilePaths mmapDirPath := paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue() - err := li.appendFieldInfo(collectionID, partitionID, segmentID, fieldID, fieldType, mmapDirPath) + err := li.appendFieldInfo(collectionID, partitionID, segmentID, fieldID, fieldType, enableMmap, mmapDirPath) if err != nil { return err } @@ -133,15 +133,16 @@ func (li *LoadIndexInfo) appendIndexFile(filePath string) error { } // appendFieldInfo appends fieldID & fieldType to index -func (li *LoadIndexInfo) appendFieldInfo(collectionID int64, partitionID int64, segmentID int64, fieldID int64, fieldType schemapb.DataType, mmapDirPath string) error { +func (li *LoadIndexInfo) appendFieldInfo(collectionID int64, partitionID int64, segmentID int64, fieldID int64, fieldType schemapb.DataType, enableMmap bool, mmapDirPath string) error { cColID := C.int64_t(collectionID) cParID := C.int64_t(partitionID) cSegID := C.int64_t(segmentID) cFieldID := C.int64_t(fieldID) cintDType := uint32(fieldType) + cEnableMmap := C.bool(enableMmap) cMmapDirPath := C.CString(mmapDirPath) defer C.free(unsafe.Pointer(cMmapDirPath)) - status := C.AppendFieldInfo(li.cLoadIndexInfo, cColID, cParID, cSegID, cFieldID, cintDType, cMmapDirPath) + status := C.AppendFieldInfo(li.cLoadIndexInfo, cColID, cParID, cSegID, cFieldID, cintDType, cEnableMmap, cMmapDirPath) return HandleCStatus(&status, "AppendFieldInfo failed") } diff --git a/internal/querynodev2/segments/reduce_test.go b/internal/querynodev2/segments/reduce_test.go index 573a3dfe6f..35a653e253 100644 --- a/internal/querynodev2/segments/reduce_test.go +++ b/internal/querynodev2/segments/reduce_test.go @@ -93,7 +93,7 @@ func (suite *ReduceSuite) SetupTest() { ) suite.Require().NoError(err) for _, binlog := range binlogs { - err = suite.segment.LoadFieldData(binlog.FieldID, int64(msgLength), binlog) + err = suite.segment.LoadFieldData(binlog.FieldID, int64(msgLength), binlog, false) suite.Require().NoError(err) } } diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index c356948b52..d66ad30a19 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -104,7 +104,7 @@ func (suite *RetrieveSuite) SetupTest() { ) suite.Require().NoError(err) for _, binlog := range binlogs { - err = suite.sealed.LoadFieldData(binlog.FieldID, int64(msgLength), binlog) + err = suite.sealed.LoadFieldData(binlog.FieldID, int64(msgLength), binlog, false) suite.Require().NoError(err) } diff --git a/internal/querynodev2/segments/search_test.go b/internal/querynodev2/segments/search_test.go index 9c7d257f55..d9b8723a18 100644 --- a/internal/querynodev2/segments/search_test.go +++ b/internal/querynodev2/segments/search_test.go @@ -95,7 +95,7 @@ func (suite *SearchSuite) SetupTest() { ) suite.Require().NoError(err) for _, binlog := range binlogs { - err = suite.sealed.LoadFieldData(binlog.FieldID, int64(msgLength), binlog) + err = suite.sealed.LoadFieldData(binlog.FieldID, int64(msgLength), binlog, false) suite.Require().NoError(err) } diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 288994c77f..ae6193173d 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -655,7 +655,7 @@ func (s *LocalSegment) LoadMultiFieldData(rowCount int64, fields []*datapb.Field return nil } -func (s *LocalSegment) LoadFieldData(fieldID int64, rowCount int64, field *datapb.FieldBinlog) error { +func (s *LocalSegment) LoadFieldData(fieldID int64, rowCount int64, field *datapb.FieldBinlog, mmapEnabled bool) error { s.ptrLock.RLock() defer s.ptrLock.RUnlock() @@ -834,14 +834,14 @@ func (s *LocalSegment) LoadDeltaData(deltaData *storage.DeleteData) error { return nil } -func (s *LocalSegment) LoadIndex(indexInfo *querypb.FieldIndexInfo, fieldType schemapb.DataType) error { +func (s *LocalSegment) LoadIndex(indexInfo *querypb.FieldIndexInfo, fieldType schemapb.DataType, enableMmap bool) error { loadIndexInfo, err := newLoadIndexInfo() defer deleteLoadIndexInfo(loadIndexInfo) if err != nil { return err } - err = loadIndexInfo.appendLoadIndexInfo(indexInfo, s.collectionID, s.partitionID, s.segmentID, fieldType) + err = loadIndexInfo.appendLoadIndexInfo(indexInfo, s.collectionID, s.partitionID, s.segmentID, fieldType, enableMmap) if err != nil { if loadIndexInfo.cleanLocalData() != nil { log.Warn("failed to clean cached data on disk after append index failed", diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index 283946f8d3..90a1b4c279 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -645,12 +645,21 @@ func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi } func (loader *segmentLoader) loadSealedSegmentFields(ctx context.Context, segment *LocalSegment, fields []*datapb.FieldBinlog, rowCount int64) error { + collection := loader.manager.Collection.Get(segment.Collection()) + if collection == nil { + return merr.WrapErrCollectionNotLoaded(segment.Collection(), "failed to load segment fields") + } + runningGroup, _ := errgroup.WithContext(ctx) for _, field := range fields { fieldBinLog := field fieldID := field.FieldID runningGroup.Go(func() error { - return segment.LoadFieldData(fieldID, rowCount, fieldBinLog) + return segment.LoadFieldData(fieldID, + rowCount, + fieldBinLog, + common.IsFieldMmapEnabled(collection.Schema(), fieldID), + ) }) } err := runningGroup.Wait() @@ -717,14 +726,18 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS } } - // 2. use index path to update segment indexInfo.IndexFilePaths = filteredPaths fieldType, err := loader.getFieldType(segment.Collection(), indexInfo.FieldID) if err != nil { return err } - return segment.LoadIndex(indexInfo, fieldType) + collection := loader.manager.Collection.Get(segment.Collection()) + if collection == nil { + return merr.WrapErrCollectionNotLoaded(segment.Collection(), "failed to load field index") + } + + return segment.LoadIndex(indexInfo, fieldType, common.IsFieldMmapEnabled(collection.Schema(), indexInfo.GetFieldID())) } func (loader *segmentLoader) loadBloomFilter(ctx context.Context, segmentID int64, bfs *pkoracle.BloomFilterSet, @@ -932,11 +945,13 @@ func (loader *segmentLoader) checkSegmentSize(ctx context.Context, segmentLoadIn metrics.QueryNodeDiskUsedSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(toMB(uint64(localDiskUsage))) diskUsage := uint64(localDiskUsage) + loader.committedResource.DiskSize - mmapEnabled := len(paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue()) > 0 maxSegmentSize := uint64(0) predictMemUsage := memUsage predictDiskUsage := diskUsage + mmapFieldCount := 0 for _, loadInfo := range segmentLoadInfos { + collection := loader.manager.Collection.Get(loadInfo.GetCollectionID()) + oldUsedMem := predictMemUsage vecFieldID2IndexInfo := make(map[int64]*querypb.FieldIndexInfo) for _, fieldIndexInfo := range loadInfo.IndexInfos { @@ -948,6 +963,7 @@ func (loader *segmentLoader) checkSegmentSize(ctx context.Context, segmentLoadIn for _, fieldBinlog := range loadInfo.BinlogPaths { fieldID := fieldBinlog.FieldID + mmapEnabled := common.IsFieldMmapEnabled(collection.Schema(), fieldID) if fieldIndexInfo, ok := vecFieldID2IndexInfo[fieldID]; ok { neededMemSize, neededDiskSize, err := GetIndexResourceUsage(fieldIndexInfo) if err != nil { @@ -972,6 +988,10 @@ func (loader *segmentLoader) checkSegmentSize(ctx context.Context, segmentLoadIn predictMemUsage += uint64(getBinlogDataSize(fieldBinlog)) } } + + if mmapEnabled { + mmapFieldCount++ + } } // get size of stats data @@ -998,20 +1018,10 @@ func (loader *segmentLoader) checkSegmentSize(ctx context.Context, segmentLoadIn zap.Float64("diskUsage", toMB(diskUsage)), zap.Float64("predictMemUsage", toMB(predictMemUsage)), zap.Float64("predictDiskUsage", toMB(predictDiskUsage)), - zap.Bool("mmapEnabled", mmapEnabled), + zap.Int("mmapFieldCount", mmapFieldCount), ) - if !mmapEnabled && predictMemUsage > uint64(float64(totalMem)*paramtable.Get().QueryNodeCfg.OverloadedMemoryThresholdPercentage.GetAsFloat()) { - return 0, 0, fmt.Errorf("load segment failed, OOM if load, maxSegmentSize = %v MB, concurrency = %d, memUsage = %v MB, predictMemUsage = %v MB, totalMem = %v MB thresholdFactor = %f", - toMB(maxSegmentSize), - concurrency, - toMB(memUsage), - toMB(predictMemUsage), - toMB(totalMem), - paramtable.Get().QueryNodeCfg.OverloadedMemoryThresholdPercentage.GetAsFloat()) - } - - if mmapEnabled && memUsage > uint64(float64(totalMem)*paramtable.Get().QueryNodeCfg.OverloadedMemoryThresholdPercentage.GetAsFloat()) { + if predictMemUsage > uint64(float64(totalMem)*paramtable.Get().QueryNodeCfg.OverloadedMemoryThresholdPercentage.GetAsFloat()) { return 0, 0, fmt.Errorf("load segment failed, OOM if load, maxSegmentSize = %v MB, concurrency = %d, memUsage = %v MB, predictMemUsage = %v MB, totalMem = %v MB thresholdFactor = %f", toMB(maxSegmentSize), concurrency, diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go index 31524f7651..ab105808a0 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/initcore" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" @@ -60,7 +61,6 @@ func (suite *SegmentLoaderSuite) SetupSuite() { suite.partitionID = rand.Int63() suite.segmentID = rand.Int63() suite.segmentNum = 5 - suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64) } func (suite *SegmentLoaderSuite) SetupTest() { @@ -76,14 +76,14 @@ func (suite *SegmentLoaderSuite) SetupTest() { initcore.InitRemoteChunkManager(paramtable.Get()) // Data - schema := GenTestCollectionSchema("test", schemapb.DataType_Int64) - indexMeta := GenTestIndexMeta(suite.collectionID, schema) + suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64) + indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema) loadMeta := &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, CollectionID: suite.collectionID, PartitionIDs: []int64{suite.partitionID}, } - suite.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, loadMeta) + suite.manager.Collection.PutOrRef(suite.collectionID, suite.schema, indexMeta, loadMeta) } func (suite *SegmentLoaderSuite) TearDownTest() { @@ -439,6 +439,14 @@ func (suite *SegmentLoaderSuite) TestLoadWithMmap() { defer paramtable.Get().Reset(key) ctx := context.Background() + collection := suite.manager.Collection.Get(suite.collectionID) + for _, field := range collection.Schema().GetFields() { + field.TypeParams = append(field.TypeParams, &commonpb.KeyValuePair{ + Key: common.MmapEnabledKey, + Value: "true", + }) + } + msgLength := 100 // Load sealed binlogs, statsLogs, err := SaveBinLog(ctx, diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index 6c8b90e3e6..87fd9a476d 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -82,7 +82,7 @@ func (suite *SegmentSuite) SetupTest() { ) suite.Require().NoError(err) for _, binlog := range binlogs { - err = suite.sealed.LoadFieldData(binlog.FieldID, int64(msgLength), binlog) + err = suite.sealed.LoadFieldData(binlog.FieldID, int64(msgLength), binlog, false) suite.Require().NoError(err) } diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 1337d3fe12..1e24fabccd 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -447,6 +447,10 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen return merr.Success(), nil } + node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(), + node.composeIndexMeta(req.GetIndexInfoList(), req.GetSchema()), req.GetLoadMeta()) + defer node.manager.Collection.Unref(req.GetCollectionID(), 1) + if req.GetLoadScope() == querypb.LoadScope_Delta { return node.loadDeltaLogs(ctx, req), nil } @@ -454,10 +458,6 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen return node.loadIndex(ctx, req), nil } - node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(), - node.composeIndexMeta(req.GetIndexInfoList(), req.GetSchema()), req.GetLoadMeta()) - defer node.manager.Collection.Unref(req.GetCollectionID(), 1) - // Actual load segment log.Info("start to load segments...") loaded, err := node.loader.Load(ctx, diff --git a/pkg/common/common.go b/pkg/common/common.go index 1ec3bfe3bd..946611e5dc 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -16,7 +16,12 @@ package common -import "encoding/binary" +import ( + "encoding/binary" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) // system field id: // 0: unique row id @@ -119,6 +124,11 @@ const ( CollectionDiskQuotaKey = "collection.diskProtection.diskQuota.mb" ) +// common properties +const ( + MmapEnabledKey = "mmap.enabled" +) + const ( PropertiesKey string = "properties" TraceIDKey string = "uber-trace-id" @@ -128,6 +138,24 @@ func IsSystemField(fieldID int64) bool { return fieldID < StartOfUserFieldID } +func IsMmapEnabled(kvs ...*commonpb.KeyValuePair) bool { + for _, kv := range kvs { + if kv.Key == MmapEnabledKey && kv.Value == "true" { + return true + } + } + return false +} + +func IsFieldMmapEnabled(schema *schemapb.CollectionSchema, fieldID int64) bool { + for _, field := range schema.GetFields() { + if field.GetFieldID() == fieldID { + return IsMmapEnabled(field.GetTypeParams()...) + } + } + return false +} + const ( // LatestVerision is the magic number for watch latest revision LatestRevision = int64(-1)