From 27cc9f263003ecb7648e85b86e44e9f0d007d50d Mon Sep 17 00:00:00 2001 From: "cai.zhang" Date: Thu, 6 Jun 2024 17:37:51 +0800 Subject: [PATCH] enhance: Support analyze data (#33651) issue: #30633 Signed-off-by: Cai Zhang Co-authored-by: chasingegg --- Makefile | 1 + configs/milvus.yaml | 29 + internal/core/CMakeLists.txt | 6 + internal/core/src/CMakeLists.txt | 1 + internal/core/src/clustering/CMakeLists.txt | 24 + .../core/src/clustering/KmeansClustering.cpp | 532 ++++++ .../core/src/clustering/KmeansClustering.h | 157 ++ internal/core/src/clustering/analyze_c.cpp | 157 ++ internal/core/src/clustering/analyze_c.h | 40 + internal/core/src/clustering/file_utils.h | 69 + .../src/clustering/milvus_clustering.pc.in | 9 + internal/core/src/clustering/type_c.h | 17 + internal/core/src/clustering/types.h | 41 + internal/core/src/common/Consts.h | 5 + internal/core/src/common/EasyAssert.h | 1 + internal/core/src/indexbuilder/types.h | 1 + internal/core/src/pb/CMakeLists.txt | 2 +- .../core/src/storage/DiskFileManagerImpl.cpp | 1 + internal/core/src/storage/FileManager.h | 5 + internal/core/unittest/CMakeLists.txt | 8 + .../core/unittest/test_kmeans_clustering.cpp | 321 ++++ .../unittest/test_utils/storage_test_utils.h | 58 + internal/datacoord/analyze_meta.go | 182 ++ internal/datacoord/analyze_meta_test.go | 267 +++ internal/datacoord/garbage_collector.go | 2 +- internal/datacoord/import_checker_test.go | 1 + internal/datacoord/import_scheduler_test.go | 1 + internal/datacoord/import_util_test.go | 3 + internal/datacoord/index_builder.go | 575 ------ internal/datacoord/index_builder_test.go | 1611 ----------------- internal/datacoord/index_meta.go | 10 +- internal/datacoord/index_meta_test.go | 18 +- internal/datacoord/index_service.go | 10 +- internal/datacoord/indexnode_manager.go | 87 +- internal/datacoord/indexnode_manager_test.go | 10 +- internal/datacoord/meta.go | 29 +- internal/datacoord/meta_test.go | 5 +- internal/datacoord/mock_compaction_meta.go | 129 ++ internal/datacoord/mock_worker_manager.go | 335 ++++ internal/datacoord/server.go | 14 +- internal/datacoord/task_analyze.go | 286 +++ internal/datacoord/task_index.go | 337 ++++ internal/datacoord/task_scheduler.go | 296 +++ internal/datacoord/task_scheduler_test.go | 1423 +++++++++++++++ internal/datacoord/types.go | 40 + internal/datacoord/util.go | 17 + .../distributed/indexnode/client/client.go | 18 + .../indexnode/client/client_test.go | 18 + internal/distributed/indexnode/service.go | 12 + .../distributed/indexnode/service_test.go | 60 +- internal/indexnode/indexnode.go | 26 +- internal/indexnode/indexnode_mock.go | 73 + internal/indexnode/indexnode_service.go | 265 ++- internal/indexnode/indexnode_service_test.go | 130 ++ internal/indexnode/indexnode_test.go | 6 +- internal/indexnode/task.go | 565 +----- internal/indexnode/task_analyze.go | 215 +++ internal/indexnode/task_index.go | 570 ++++++ internal/indexnode/task_scheduler.go | 32 +- internal/indexnode/task_scheduler_test.go | 42 +- internal/indexnode/task_test.go | 210 ++- internal/indexnode/taskinfo_ops.go | 167 +- internal/indexnode/util_test.go | 17 + internal/metastore/catalog.go | 5 + internal/metastore/kv/datacoord/constant.go | 1 + internal/metastore/kv/datacoord/kv_catalog.go | 38 + internal/metastore/kv/datacoord/util.go | 4 + .../metastore/mocks/mock_datacoord_catalog.go | 141 ++ internal/mocks/mock_indexnode.go | 165 ++ internal/mocks/mock_indexnode_client.go | 210 +++ internal/proto/clustering.proto | 55 + internal/proto/index_coord.proto | 156 +- 
internal/util/analyzecgowrapper/analyze.go | 116 ++ internal/util/analyzecgowrapper/helper.go | 55 + internal/util/mock/grpc_indexnode_client.go | 12 + pkg/common/common.go | 5 + pkg/util/paramtable/component_param.go | 215 ++- pkg/util/paramtable/component_param_test.go | 28 + pkg/util/typeutil/schema.go | 24 + pkg/util/typeutil/schema_test.go | 34 + scripts/generate_proto.sh | 5 +- tests/go_client/go.mod | 4 +- 82 files changed, 7942 insertions(+), 2930 deletions(-) create mode 100644 internal/core/src/clustering/CMakeLists.txt create mode 100644 internal/core/src/clustering/KmeansClustering.cpp create mode 100644 internal/core/src/clustering/KmeansClustering.h create mode 100644 internal/core/src/clustering/analyze_c.cpp create mode 100644 internal/core/src/clustering/analyze_c.h create mode 100644 internal/core/src/clustering/file_utils.h create mode 100644 internal/core/src/clustering/milvus_clustering.pc.in create mode 100644 internal/core/src/clustering/type_c.h create mode 100644 internal/core/src/clustering/types.h create mode 100644 internal/core/unittest/test_kmeans_clustering.cpp create mode 100644 internal/datacoord/analyze_meta.go create mode 100644 internal/datacoord/analyze_meta_test.go delete mode 100644 internal/datacoord/index_builder.go delete mode 100644 internal/datacoord/index_builder_test.go create mode 100644 internal/datacoord/mock_worker_manager.go create mode 100644 internal/datacoord/task_analyze.go create mode 100644 internal/datacoord/task_index.go create mode 100644 internal/datacoord/task_scheduler.go create mode 100644 internal/datacoord/task_scheduler_test.go create mode 100644 internal/datacoord/types.go create mode 100644 internal/indexnode/task_analyze.go create mode 100644 internal/indexnode/task_index.go create mode 100644 internal/proto/clustering.proto create mode 100644 internal/util/analyzecgowrapper/analyze.go create mode 100644 internal/util/analyzecgowrapper/helper.go diff --git a/Makefile b/Makefile index ba553753c3..b02415b316 100644 --- a/Makefile +++ b/Makefile @@ -465,6 +465,7 @@ generate-mockery-datacoord: getdeps $(INSTALL_PATH)/mockery --name=ChannelManager --dir=internal/datacoord --filename=mock_channelmanager.go --output=internal/datacoord --structname=MockChannelManager --with-expecter --inpackage $(INSTALL_PATH)/mockery --name=SubCluster --dir=internal/datacoord --filename=mock_subcluster.go --output=internal/datacoord --structname=MockSubCluster --with-expecter --inpackage $(INSTALL_PATH)/mockery --name=Broker --dir=internal/datacoord/broker --filename=mock_coordinator_broker.go --output=internal/datacoord/broker --structname=MockBroker --with-expecter --inpackage + $(INSTALL_PATH)/mockery --name=WorkerManager --dir=internal/datacoord --filename=mock_worker_manager.go --output=internal/datacoord --structname=MockWorkerManager --with-expecter --inpackage generate-mockery-datanode: getdeps $(INSTALL_PATH)/mockery --name=Allocator --dir=$(PWD)/internal/datanode/allocator --output=$(PWD)/internal/datanode/allocator --filename=mock_allocator.go --with-expecter --structname=MockAllocator --outpkg=allocator --inpackage diff --git a/configs/milvus.yaml b/configs/milvus.yaml index d6415e1a76..cc87eeb1e0 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -448,6 +448,30 @@ dataCoord: rpcTimeout: 10 maxParallelTaskNum: 10 workerMaxParallelTaskNum: 2 + clustering: + enable: true + autoEnable: false + triggerInterval: 600 + stateCheckInterval: 10 + gcInterval: 600 + minInterval: 3600 + maxInterval: 259200 + newDataRatioThreshold: 
0.2 + newDataSizeThreshold: 512m + timeout: 7200 + dropTolerance: 86400 + # clustering compaction will try best to distribute data into segments with size range in [preferSegmentSize, maxSegmentSize]. + # data will be clustered by preferSegmentSize, if a cluster is larger than maxSegmentSize, will spilt it into multi segment + # buffer between (preferSegmentSize, maxSegmentSize) is left for new data in the same cluster(range), to avoid globally redistribute too often + preferSegmentSize: 512m + maxSegmentSize: 1024m + maxTrainSizeRatio: 0.8 # max data size ratio in analyze, if data is larger than it, will down sampling to meet this limit + maxCentroidsNum: 10240 + minCentroidsNum: 16 + minClusterSizeRatio: 0.01 + maxClusterSizeRatio: 10 + maxClusterSize: 5g + levelzero: forceTrigger: minSize: 8388608 # The minmum size in bytes to force trigger a LevelZero Compaction, default as 8MB @@ -535,6 +559,9 @@ dataNode: slot: slotCap: 2 # The maximum number of tasks(e.g. compaction, importing) allowed to run concurrently on a datanode. + clusteringCompaction: + memoryBufferRatio: 0.1 # The ratio of memory buffer of clustering compaction. Data larger than threshold will be spilled to storage. + # Configures the system log output. log: level: info # Only supports debug, info, warn, error, panic, or fatal. Default 'info'. @@ -615,6 +642,8 @@ common: traceLogMode: 0 # trace request info bloomFilterSize: 100000 # bloom filter initial size maxBloomFalsePositive: 0.001 # max false positive rate for bloom filter + usePartitionKeyAsClusteringKey: false + useVectorAsClusteringKey: false # QuotaConfig, configurations of Milvus quota and limits. # By default, we enable: diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt index 1c22ee5120..407220304a 100644 --- a/internal/core/CMakeLists.txt +++ b/internal/core/CMakeLists.txt @@ -293,6 +293,12 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/indexbuilder/ FILES_MATCHING PATTERN "*_c.h" ) +# Install clustering +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/clustering/ + DESTINATION include/clustering + FILES_MATCHING PATTERN "*_c.h" +) + # Install common install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/common/ DESTINATION include/common diff --git a/internal/core/src/CMakeLists.txt b/internal/core/src/CMakeLists.txt index 50e34e99b5..c6da67afb1 100644 --- a/internal/core/src/CMakeLists.txt +++ b/internal/core/src/CMakeLists.txt @@ -32,5 +32,6 @@ add_subdirectory( index ) add_subdirectory( query ) add_subdirectory( segcore ) add_subdirectory( indexbuilder ) +add_subdirectory( clustering ) add_subdirectory( exec ) add_subdirectory( bitset ) diff --git a/internal/core/src/clustering/CMakeLists.txt b/internal/core/src/clustering/CMakeLists.txt new file mode 100644 index 0000000000..40833d9ef2 --- /dev/null +++ b/internal/core/src/clustering/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing permissions and limitations under the License + + +set(CLUSTERING_FILES + analyze_c.cpp + KmeansClustering.cpp + ) + +milvus_add_pkg_config("milvus_clustering") +add_library(milvus_clustering SHARED ${CLUSTERING_FILES}) + +# link order matters +target_link_libraries(milvus_clustering milvus_index) + +install(TARGETS milvus_clustering DESTINATION "${CMAKE_INSTALL_LIBDIR}") diff --git a/internal/core/src/clustering/KmeansClustering.cpp b/internal/core/src/clustering/KmeansClustering.cpp new file mode 100644 index 0000000000..4c88a058d1 --- /dev/null +++ b/internal/core/src/clustering/KmeansClustering.cpp @@ -0,0 +1,532 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "index/VectorDiskIndex.h" + +#include "common/Tracer.h" +#include "common/Utils.h" +#include "config/ConfigKnowhere.h" +#include "index/Meta.h" +#include "index/Utils.h" +#include "knowhere/cluster/cluster_factory.h" +#include "knowhere/comp/time_recorder.h" +#include "clustering/KmeansClustering.h" +#include "segcore/SegcoreConfig.h" +#include "storage/LocalChunkManagerSingleton.h" +#include "storage/Util.h" +#include "common/Consts.h" +#include "common/RangeSearchHelper.h" +#include "clustering/types.h" +#include "clustering/file_utils.h" +#include + +namespace milvus::clustering { + +KmeansClustering::KmeansClustering( + const storage::FileManagerContext& file_manager_context) { + file_manager_ = + std::make_unique(file_manager_context); + AssertInfo(file_manager_ != nullptr, "create file manager failed!"); + int64_t collection_id = file_manager_context.fieldDataMeta.collection_id; + int64_t partition_id = file_manager_context.fieldDataMeta.partition_id; + msg_header_ = fmt::format( + "collection: {}, partition: {} ", collection_id, partition_id); +} + +template +void +KmeansClustering::FetchDataFiles(uint8_t* buf, + const int64_t expected_train_size, + const int64_t expected_remote_file_size, + const std::vector& files, + const int64_t dim, + int64_t& offset) { + // CacheRawDataToMemory mostly used as pull files from one segment + // So we could assume memory is always enough for theses cases + // But in clustering when we sample train data, first pre-allocate the large buffer(size controlled by config) for future knowhere usage + // And we will have tmp memory usage at pulling stage, pull file(tmp memory) + memcpy to pre-allocated buffer, limit the batch here + auto batch = size_t(DEFAULT_FIELD_MAX_MEMORY_LIMIT / FILE_SLICE_SIZE); + int64_t fetched_file_size = 0; + + for (size_t i = 0; i < files.size(); i += batch) { + size_t start = i; + size_t end = std::min(files.size(), i + batch); + std::vector group_files(files.begin() + start, + files.begin() + end); + auto field_datas = 
file_manager_->CacheRawDataToMemory(group_files); + + for (auto& data : field_datas) { + size_t size = std::min(expected_train_size - offset, data->Size()); + if (size <= 0) { + break; + } + fetched_file_size += size; + std::memcpy(buf + offset, data->Data(), size); + offset += size; + data.reset(); + } + } + AssertInfo(fetched_file_size == expected_remote_file_size, + "file size inconsistent, expected: {}, actual: {}", + expected_remote_file_size, + fetched_file_size); +} + +template +void +KmeansClustering::SampleTrainData( + const std::vector& segment_ids, + const std::map>& segment_file_paths, + const std::map& segment_num_rows, + const int64_t expected_train_size, + const int64_t dim, + const bool random_sample, + uint8_t* buf) { + int64_t offset = 0; + std::vector files; + + if (random_sample) { + for (auto& [segment_id, segment_files] : segment_file_paths) { + for (auto& segment_file : segment_files) { + files.emplace_back(segment_file); + } + } + // shuffle files + std::shuffle(files.begin(), files.end(), std::mt19937()); + FetchDataFiles( + buf, expected_train_size, expected_train_size, files, dim, offset); + return; + } + + // pick all segment_ids, no shuffle + // and pull data once each segment to reuse the id mapping for assign stage + for (auto i = 0; i < segment_ids.size(); i++) { + if (offset == expected_train_size) { + break; + } + int64_t cur_segment_id = segment_ids[i]; + files = segment_file_paths.at(cur_segment_id); + std::sort(files.begin(), + files.end(), + [](const std::string& a, const std::string& b) { + return std::stol(a.substr(a.find_last_of("/") + 1)) < + std::stol(b.substr(b.find_last_of("/") + 1)); + }); + FetchDataFiles(buf, + expected_train_size, + segment_num_rows.at(cur_segment_id) * dim * sizeof(T), + files, + dim, + offset); + } +} + +template +milvus::proto::clustering::ClusteringCentroidsStats +KmeansClustering::CentroidsToPB(const T* centroids, + const int64_t num_clusters, + const int64_t dim) { + milvus::proto::clustering::ClusteringCentroidsStats stats; + for (auto i = 0; i < num_clusters; i++) { + milvus::proto::schema::VectorField* vector_field = + stats.add_centroids(); + vector_field->set_dim(dim); + milvus::proto::schema::FloatArray* float_array = + vector_field->mutable_float_vector(); + for (auto j = 0; j < dim; j++) { + float_array->add_data(float(centroids[i * dim + j])); + } + } + return stats; +} + +std::vector +KmeansClustering::CentroidIdMappingToPB( + const uint32_t* centroid_id_mapping, + const std::vector& segment_ids, + const int64_t trained_segments_num, + const std::map& num_row_map, + const int64_t num_clusters) { + auto compute_num_in_centroid = [&](const uint32_t* centroid_id_mapping, + uint64_t start, + uint64_t end) -> std::vector { + std::vector num_vectors(num_clusters, 0); + for (uint64_t i = start; i < end; ++i) { + num_vectors[centroid_id_mapping[i]]++; + } + return num_vectors; + }; + std::vector + stats_arr; + stats_arr.reserve(trained_segments_num); + int64_t cur_offset = 0; + for (auto i = 0; i < trained_segments_num; i++) { + milvus::proto::clustering::ClusteringCentroidIdMappingStats stats; + auto num_offset = num_row_map.at(segment_ids[i]); + for (auto j = 0; j < num_offset; j++) { + stats.add_centroid_id_mapping(centroid_id_mapping[cur_offset + j]); + } + auto num_vectors = compute_num_in_centroid( + centroid_id_mapping, cur_offset, cur_offset + num_offset); + for (uint64_t j = 0; j < num_clusters; j++) { + stats.add_num_in_centroid(num_vectors[j]); + } + cur_offset += num_offset; + 
stats_arr.emplace_back(stats); + } + return stats_arr; +} + +template +bool +KmeansClustering::IsDataSkew( + const milvus::proto::clustering::AnalyzeInfo& config, + const int64_t dim, + std::vector& num_in_each_centroid) { + auto min_cluster_ratio = config.min_cluster_ratio(); + auto max_cluster_ratio = config.max_cluster_ratio(); + auto max_cluster_size = config.max_cluster_size(); + std::sort(num_in_each_centroid.begin(), num_in_each_centroid.end()); + size_t avg_size = + std::accumulate( + num_in_each_centroid.begin(), num_in_each_centroid.end(), 0) / + (num_in_each_centroid.size()); + if (num_in_each_centroid.front() <= min_cluster_ratio * avg_size) { + LOG_INFO(msg_header_ + "minimum cluster too small: {}, avg: {}", + num_in_each_centroid.front(), + avg_size); + return true; + } + if (num_in_each_centroid.back() >= max_cluster_ratio * avg_size) { + LOG_INFO(msg_header_ + "maximum cluster too large: {}, avg: {}", + num_in_each_centroid.back(), + avg_size); + return true; + } + if (num_in_each_centroid.back() * dim * sizeof(T) >= max_cluster_size) { + LOG_INFO(msg_header_ + "maximum cluster size too large: {}B", + num_in_each_centroid.back() * dim * sizeof(T)); + return true; + } + return false; +} + +template +void +KmeansClustering::StreamingAssignandUpload( + knowhere::Cluster& cluster_node, + const milvus::proto::clustering::AnalyzeInfo& config, + const milvus::proto::clustering::ClusteringCentroidsStats& centroid_stats, + const std::vector< + milvus::proto::clustering::ClusteringCentroidIdMappingStats>& + id_mapping_stats, + const std::vector& segment_ids, + const std::map>& insert_files, + const std::map& num_rows, + const int64_t dim, + const int64_t trained_segments_num, + const int64_t num_clusters) { + auto byte_size = centroid_stats.ByteSizeLong(); + std::unique_ptr data = std::make_unique(byte_size); + centroid_stats.SerializeToArray(data.get(), byte_size); + std::unordered_map remote_paths_to_size; + LOG_INFO(msg_header_ + "start upload cluster centroids file"); + AddClusteringResultFiles( + file_manager_->GetChunkManager().get(), + data.get(), + byte_size, + GetRemoteCentroidsObjectPrefix() + "/" + std::string(CENTROIDS_NAME), + remote_paths_to_size); + cluster_result_.centroid_path = + GetRemoteCentroidsObjectPrefix() + "/" + std::string(CENTROIDS_NAME); + cluster_result_.centroid_file_size = + remote_paths_to_size.at(cluster_result_.centroid_path); + remote_paths_to_size.clear(); + LOG_INFO(msg_header_ + "upload cluster centroids file done"); + + LOG_INFO(msg_header_ + "start upload cluster id mapping file"); + std::vector num_vectors_each_centroid(num_clusters, 0); + + auto serializeIdMappingAndUpload = [&](const int64_t segment_id, + const milvus::proto::clustering:: + ClusteringCentroidIdMappingStats& + id_mapping_pb) { + auto byte_size = id_mapping_pb.ByteSizeLong(); + std::unique_ptr data = + std::make_unique(byte_size); + id_mapping_pb.SerializeToArray(data.get(), byte_size); + AddClusteringResultFiles( + file_manager_->GetChunkManager().get(), + data.get(), + byte_size, + GetRemoteCentroidIdMappingObjectPrefix(segment_id) + "/" + + std::string(OFFSET_MAPPING_NAME), + remote_paths_to_size); + LOG_INFO( + msg_header_ + + "upload segment {} cluster id mapping file with size {} B done", + segment_id, + byte_size); + }; + + for (size_t i = 0; i < segment_ids.size(); i++) { + int64_t segment_id = segment_ids[i]; + // id mapping has been computed, just upload to remote + if (i < trained_segments_num) { + serializeIdMappingAndUpload(segment_id, id_mapping_stats[i]); + for 
(int64_t j = 0; j < num_clusters; ++j) { + num_vectors_each_centroid[j] += + id_mapping_stats[i].num_in_centroid(j); + } + } else { // streaming download raw data, assign id mapping, then upload + int64_t num_row = num_rows.at(segment_id); + std::unique_ptr buf = std::make_unique(num_row * dim); + int64_t offset = 0; + FetchDataFiles(reinterpret_cast(buf.get()), + INT64_MAX, + num_row * dim * sizeof(T), + insert_files.at(segment_id), + dim, + offset); + auto dataset = GenDataset(num_row, dim, buf.release()); + dataset->SetIsOwner(true); + auto res = cluster_node.Assign(*dataset); + if (!res.has_value()) { + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("failed to kmeans assign: {}: {}", + KnowhereStatusString(res.error()), + res.what())); + } + res.value()->SetIsOwner(true); + auto id_mapping = + reinterpret_cast(res.value()->GetTensor()); + + auto id_mapping_pb = CentroidIdMappingToPB( + id_mapping, {segment_id}, 1, num_rows, num_clusters)[0]; + for (int64_t j = 0; j < num_clusters; ++j) { + num_vectors_each_centroid[j] += + id_mapping_pb.num_in_centroid(j); + } + serializeIdMappingAndUpload(segment_id, id_mapping_pb); + } + } + if (IsDataSkew(config, dim, num_vectors_each_centroid)) { + LOG_INFO(msg_header_ + "data skew! skip clustering"); + // remove uploaded files + RemoveClusteringResultFiles(file_manager_->GetChunkManager().get(), + remote_paths_to_size); + // skip clustering, nothing takes affect + throw SegcoreError(ErrorCode::ClusterSkip, + "data skew! skip clustering"); + } + LOG_INFO(msg_header_ + "upload cluster id mapping file done"); + cluster_result_.id_mappings = std::move(remote_paths_to_size); + is_runned_ = true; +} + +template +void +KmeansClustering::Run(const milvus::proto::clustering::AnalyzeInfo& config) { + std::map> insert_files; + for (const auto& pair : config.insert_files()) { + std::vector segment_files( + pair.second.insert_files().begin(), + pair.second.insert_files().end()); + insert_files[pair.first] = segment_files; + } + + std::map num_rows(config.num_rows().begin(), + config.num_rows().end()); + auto num_clusters = config.num_clusters(); + AssertInfo(num_clusters > 0, "num clusters must larger than 0"); + auto train_size = config.train_size(); + AssertInfo(train_size > 0, "train size must larger than 0"); + auto dim = config.dim(); + auto min_cluster_ratio = config.min_cluster_ratio(); + AssertInfo(min_cluster_ratio > 0 && min_cluster_ratio < 1, + "min cluster ratio must larger than 0, less than 1"); + auto max_cluster_ratio = config.max_cluster_ratio(); + AssertInfo(max_cluster_ratio > 1, "max cluster ratio must larger than 1"); + auto max_cluster_size = config.max_cluster_size(); + AssertInfo(max_cluster_size > 0, "max cluster size must larger than 0"); + + auto cluster_node_obj = + knowhere::ClusterFactory::Instance().Create(KMEANS_CLUSTER); + knowhere::Cluster cluster_node; + if (cluster_node_obj.has_value()) { + cluster_node = std::move(cluster_node_obj.value()); + } else { + auto err = cluster_node_obj.error(); + if (err == knowhere::Status::invalid_cluster_error) { + throw SegcoreError(ErrorCode::ClusterSkip, cluster_node_obj.what()); + } + throw SegcoreError(ErrorCode::KnowhereError, cluster_node_obj.what()); + } + + size_t data_num = 0; + std::vector segment_ids; + for (auto& [segment_id, num_row_each_segment] : num_rows) { + data_num += num_row_each_segment; + segment_ids.emplace_back(segment_id); + AssertInfo(insert_files.find(segment_id) != insert_files.end(), + "segment id {} not exist in insert files", + segment_id); + } + size_t 
trained_segments_num = 0; + + size_t data_size = data_num * dim * sizeof(T); + size_t train_num = train_size / sizeof(T) / dim; + bool random_sample = true; + // make train num equal to data num + if (train_num >= data_num) { + train_num = data_num; + random_sample = + false; // all data are used for training, no need to random sampling + trained_segments_num = segment_ids.size(); + } + if (train_num < num_clusters) { + LOG_WARN(msg_header_ + + "kmeans train num: {} less than num_clusters: {}, skip " + "clustering", + train_num, + num_clusters); + throw SegcoreError(ErrorCode::ClusterSkip, + "sample data num less than num clusters"); + } + + size_t train_size_final = train_num * dim * sizeof(T); + knowhere::TimeRecorder rc(msg_header_ + "kmeans clustering", + 2 /* log level: info */); + // if data_num larger than max_train_size, we need to sample to make train data fits in memory + // otherwise just load all the data for kmeans training + LOG_INFO(msg_header_ + "pull and sample {}GB data out of {}GB data", + train_size_final / 1024.0 / 1024.0 / 1024.0, + data_size / 1024.0 / 1024.0 / 1024.0); + auto buf = std::make_unique(train_size_final); + SampleTrainData(segment_ids, + insert_files, + num_rows, + train_size_final, + dim, + random_sample, + buf.get()); + rc.RecordSection("sample done"); + + auto dataset = GenDataset(train_num, dim, buf.release()); + dataset->SetIsOwner(true); + + LOG_INFO(msg_header_ + "train data num: {}, dim: {}, num_clusters: {}", + train_num, + dim, + num_clusters); + knowhere::Json train_conf; + train_conf[NUM_CLUSTERS] = num_clusters; + // inside knowhere, we will record each kmeans iteration duration + // return id mapping + auto res = cluster_node.Train(*dataset, train_conf); + if (!res.has_value()) { + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("failed to kmeans train: {}: {}", + KnowhereStatusString(res.error()), + res.what())); + } + res.value()->SetIsOwner(true); + rc.RecordSection("clustering train done"); + dataset.reset(); // release train data + + auto centroid_id_mapping = + reinterpret_cast(res.value()->GetTensor()); + + auto centroids_res = cluster_node.GetCentroids(); + if (!centroids_res.has_value()) { + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("failed to get centroids: {}: {}", + KnowhereStatusString(res.error()), + res.what())); + } + // centroids owned by cluster_node + centroids_res.value()->SetIsOwner(false); + auto centroids = + reinterpret_cast(centroids_res.value()->GetTensor()); + + auto centroid_stats = CentroidsToPB(centroids, num_clusters, dim); + auto id_mapping_stats = CentroidIdMappingToPB(centroid_id_mapping, + segment_ids, + trained_segments_num, + num_rows, + num_clusters); + // upload + StreamingAssignandUpload(cluster_node, + config, + centroid_stats, + id_mapping_stats, + segment_ids, + insert_files, + num_rows, + dim, + trained_segments_num, + num_clusters); + rc.RecordSection("clustering result upload done"); + rc.ElapseFromBegin("clustering done"); +} + +template void +KmeansClustering::StreamingAssignandUpload( + knowhere::Cluster& cluster_node, + const milvus::proto::clustering::AnalyzeInfo& config, + const milvus::proto::clustering::ClusteringCentroidsStats& centroid_stats, + const std::vector< + milvus::proto::clustering::ClusteringCentroidIdMappingStats>& + id_mapping_stats, + const std::vector& segment_ids, + const std::map>& insert_files, + const std::map& num_rows, + const int64_t dim, + const int64_t trained_segments_num, + const int64_t num_clusters); + +template void 
+KmeansClustering::FetchDataFiles(uint8_t* buf, + const int64_t expected_train_size, + const int64_t expected_remote_file_size, + const std::vector& files, + const int64_t dim, + int64_t& offset); +template void +KmeansClustering::SampleTrainData( + const std::vector& segment_ids, + const std::map>& segment_file_paths, + const std::map& segment_num_rows, + const int64_t expected_train_size, + const int64_t dim, + const bool random_sample, + uint8_t* buf); + +template void +KmeansClustering::Run( + const milvus::proto::clustering::AnalyzeInfo& config); + +template milvus::proto::clustering::ClusteringCentroidsStats +KmeansClustering::CentroidsToPB(const float* centroids, + const int64_t num_clusters, + const int64_t dim); +template bool +KmeansClustering::IsDataSkew( + const milvus::proto::clustering::AnalyzeInfo& config, + const int64_t dim, + std::vector& num_in_each_centroid); + +} // namespace milvus::clustering diff --git a/internal/core/src/clustering/KmeansClustering.h b/internal/core/src/clustering/KmeansClustering.h new file mode 100644 index 0000000000..bfb7d0e4a1 --- /dev/null +++ b/internal/core/src/clustering/KmeansClustering.h @@ -0,0 +1,157 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
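// Usage sketch (illustrative only): the class declared below is driven the same
// way analyze_c.cpp and the unit test drive it; the field_meta / index_meta /
// chunk_manager wiring and the Run<float> instantiation are assumptions of this
// sketch, not requirements imposed by the header.
//
//   milvus::storage::FileManagerContext ctx(field_meta, index_meta, chunk_manager);
//   auto job = std::make_unique<milvus::clustering::KmeansClustering>(ctx);
//   job->Run<float>(analyze_info);              // may throw SegcoreError with
//                                               // ErrorCode::ClusterSkip on data skew
//                                               // or when too few rows are sampled
//   auto meta = job->GetClusteringResultMeta(); // centroid path/size plus
//                                               // per-segment id-mapping paths/sizes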
+ +#pragma once + +#include +#include +#include + +#include "storage/MemFileManagerImpl.h" +#include "storage/space.h" +#include "pb/clustering.pb.h" +#include "knowhere/cluster/cluster_factory.h" + +namespace milvus::clustering { + +// after clustering result uploaded, return result meta for golang usage +struct ClusteringResultMeta { + std::string centroid_path; // centroid result path + int64_t centroid_file_size; // centroid result size + std::unordered_map + id_mappings; // id mapping result path/size for each segment +}; + +class KmeansClustering { + public: + explicit KmeansClustering( + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); + + // every time is a brand new kmeans training + template + void + Run(const milvus::proto::clustering::AnalyzeInfo& config); + + // should never be called before run + ClusteringResultMeta + GetClusteringResultMeta() { + if (!is_runned_) { + throw SegcoreError( + ErrorCode::UnexpectedError, + "clustering result is not ready before kmeans run"); + } + return cluster_result_; + } + + // ut + inline std::string + GetRemoteCentroidsObjectPrefix() const { + auto index_meta_ = file_manager_->GetIndexMeta(); + auto field_meta_ = file_manager_->GetFieldDataMeta(); + return file_manager_->GetChunkManager()->GetRootPath() + "/" + + std::string(ANALYZE_ROOT_PATH) + "/" + + std::to_string(index_meta_.build_id) + "/" + + std::to_string(index_meta_.index_version) + "/" + + std::to_string(field_meta_.collection_id) + "/" + + std::to_string(field_meta_.partition_id) + "/" + + std::to_string(field_meta_.field_id); + } + + inline std::string + GetRemoteCentroidIdMappingObjectPrefix(int64_t segment_id) const { + auto index_meta_ = file_manager_->GetIndexMeta(); + auto field_meta_ = file_manager_->GetFieldDataMeta(); + return file_manager_->GetChunkManager()->GetRootPath() + "/" + + std::string(ANALYZE_ROOT_PATH) + "/" + + std::to_string(index_meta_.build_id) + "/" + + std::to_string(index_meta_.index_version) + "/" + + std::to_string(field_meta_.collection_id) + "/" + + std::to_string(field_meta_.partition_id) + "/" + + std::to_string(field_meta_.field_id) + "/" + + std::to_string(segment_id); + } + + ~KmeansClustering() = default; + + private: + template + void + StreamingAssignandUpload( + knowhere::Cluster& cluster_node, + const milvus::proto::clustering::AnalyzeInfo& config, + const milvus::proto::clustering::ClusteringCentroidsStats& + centroid_stats, + const std::vector< + milvus::proto::clustering::ClusteringCentroidIdMappingStats>& + id_mapping_stats, + const std::vector& segment_ids, + const std::map>& insert_files, + const std::map& num_rows, + const int64_t dim, + const int64_t trained_segments_num, + const int64_t num_clusters); + + template + void + FetchDataFiles(uint8_t* buf, + const int64_t expected_train_size, + const int64_t expected_remote_file_size, + const std::vector& files, + const int64_t dim, + int64_t& offset); + + // given all possible segments, sample data to buffer + template + void + SampleTrainData( + const std::vector& segment_ids, + const std::map>& segment_file_paths, + const std::map& segment_num_rows, + const int64_t expected_train_size, + const int64_t dim, + const bool random_sample, + uint8_t* buf); + + // transform centroids result to PB format for future usage of golang side + template + milvus::proto::clustering::ClusteringCentroidsStats + CentroidsToPB(const T* centroids, + const int64_t num_clusters, + const int64_t dim); + + // transform flattened id mapping result to several PB 
files by each segment for future usage of golang side + std::vector + CentroidIdMappingToPB(const uint32_t* centroid_id_mapping, + const std::vector& segment_ids, + const int64_t trained_segments_num, + const std::map& num_row_map, + const int64_t num_clusters); + + template + bool + IsDataSkew(const milvus::proto::clustering::AnalyzeInfo& config, + const int64_t dim, + std::vector& num_in_each_centroid); + + std::unique_ptr file_manager_; + ClusteringResultMeta cluster_result_; + bool is_runned_ = false; + std::string msg_header_; +}; + +using KmeansClusteringPtr = std::unique_ptr; +} // namespace milvus::clustering diff --git a/internal/core/src/clustering/analyze_c.cpp b/internal/core/src/clustering/analyze_c.cpp new file mode 100644 index 0000000000..8df1aec71b --- /dev/null +++ b/internal/core/src/clustering/analyze_c.cpp @@ -0,0 +1,157 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#ifdef __linux__ +#include +#endif + +#include "analyze_c.h" +#include "common/type_c.h" +#include "type_c.h" +#include "types.h" +#include "index/Utils.h" +#include "index/Meta.h" +#include "storage/Util.h" +#include "pb/clustering.pb.h" +#include "clustering/KmeansClustering.h" + +using namespace milvus; + +milvus::storage::StorageConfig +get_storage_config(const milvus::proto::clustering::StorageConfig& config) { + auto storage_config = milvus::storage::StorageConfig(); + storage_config.address = std::string(config.address()); + storage_config.bucket_name = std::string(config.bucket_name()); + storage_config.access_key_id = std::string(config.access_keyid()); + storage_config.access_key_value = std::string(config.secret_access_key()); + storage_config.root_path = std::string(config.root_path()); + storage_config.storage_type = std::string(config.storage_type()); + storage_config.cloud_provider = std::string(config.cloud_provider()); + storage_config.iam_endpoint = std::string(config.iamendpoint()); + storage_config.cloud_provider = std::string(config.cloud_provider()); + storage_config.useSSL = config.usessl(); + storage_config.sslCACert = config.sslcacert(); + storage_config.useIAM = config.useiam(); + storage_config.region = config.region(); + storage_config.useVirtualHost = config.use_virtual_host(); + storage_config.requestTimeoutMs = config.request_timeout_ms(); + return storage_config; +} + +CStatus +Analyze(CAnalyze* res_analyze, + const uint8_t* serialized_analyze_info, + const uint64_t len) { + try { + auto analyze_info = + std::make_unique(); + auto res = analyze_info->ParseFromArray(serialized_analyze_info, len); + AssertInfo(res, "Unmarshall analyze info failed"); + auto field_type = + static_cast(analyze_info->field_schema().data_type()); + auto field_id = analyze_info->field_schema().fieldid(); + + // init file manager + milvus::storage::FieldDataMeta field_meta{analyze_info->collectionid(), + analyze_info->partitionid(), + 0, + field_id}; + + milvus::storage::IndexMeta index_meta{ + 0, field_id, 
analyze_info->buildid(), analyze_info->version()}; + auto storage_config = + get_storage_config(analyze_info->storage_config()); + auto chunk_manager = + milvus::storage::CreateChunkManager(storage_config); + + milvus::storage::FileManagerContext fileManagerContext( + field_meta, index_meta, chunk_manager); + + if (field_type != DataType::VECTOR_FLOAT) { + throw SegcoreError( + DataTypeInvalid, + fmt::format("invalid data type for clustering is {}", + std::to_string(int(field_type)))); + } + auto clusteringJob = + std::make_unique( + fileManagerContext); + + clusteringJob->Run(*analyze_info); + *res_analyze = clusteringJob.release(); + auto status = CStatus(); + status.error_code = Success; + status.error_msg = ""; + return status; + } catch (SegcoreError& e) { + auto status = CStatus(); + status.error_code = e.get_error_code(); + status.error_msg = strdup(e.what()); + return status; + } catch (std::exception& e) { + auto status = CStatus(); + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + return status; + } +} + +CStatus +DeleteAnalyze(CAnalyze analyze) { + auto status = CStatus(); + try { + AssertInfo(analyze, "failed to delete analyze, passed index was null"); + auto real_analyze = + reinterpret_cast(analyze); + delete real_analyze; + status.error_code = Success; + status.error_msg = ""; + } catch (std::exception& e) { + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + } + return status; +} + +CStatus +GetAnalyzeResultMeta(CAnalyze analyze, + char** centroid_path, + int64_t* centroid_file_size, + void* id_mapping_paths, + int64_t* id_mapping_sizes) { + auto status = CStatus(); + try { + AssertInfo(analyze, + "failed to serialize analyze to binary set, passed index " + "was null"); + auto real_analyze = + reinterpret_cast(analyze); + auto res = real_analyze->GetClusteringResultMeta(); + *centroid_path = res.centroid_path.data(); + *centroid_file_size = res.centroid_file_size; + + auto& map_ = res.id_mappings; + const char** id_mapping_paths_ = (const char**)id_mapping_paths; + size_t i = 0; + for (auto it = map_.begin(); it != map_.end(); ++it, i++) { + id_mapping_paths_[i] = it->first.data(); + id_mapping_sizes[i] = it->second; + } + status.error_code = Success; + status.error_msg = ""; + } catch (std::exception& e) { + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + } + return status; +} diff --git a/internal/core/src/clustering/analyze_c.h b/internal/core/src/clustering/analyze_c.h new file mode 100644 index 0000000000..0bfa845a64 --- /dev/null +++ b/internal/core/src/clustering/analyze_c.h @@ -0,0 +1,40 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include "common/type_c.h" +#include "common/binary_set_c.h" +#include "clustering/type_c.h" + +CStatus +Analyze(CAnalyze* res_analyze, + const uint8_t* serialized_analyze_info, + const uint64_t len); + +CStatus +DeleteAnalyze(CAnalyze analyze); + +CStatus +GetAnalyzeResultMeta(CAnalyze analyze, + char** centroid_path, + int64_t* centroid_file_size, + void* id_mapping_paths, + int64_t* id_mapping_sizes); + +#ifdef __cplusplus +}; +#endif diff --git a/internal/core/src/clustering/file_utils.h b/internal/core/src/clustering/file_utils.h new file mode 100644 index 0000000000..097d57e84b --- /dev/null +++ b/internal/core/src/clustering/file_utils.h @@ -0,0 +1,69 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include "common/type_c.h" +#include +#include "storage/ThreadPools.h" + +#include "common/FieldData.h" +#include "common/LoadInfo.h" +#include "knowhere/comp/index_param.h" +#include "parquet/schema.h" +#include "storage/PayloadStream.h" +#include "storage/FileManager.h" +#include "storage/BinlogReader.h" +#include "storage/ChunkManager.h" +#include "storage/DataCodec.h" +#include "storage/Types.h" +#include "storage/space.h" + +namespace milvus::clustering { + +void +AddClusteringResultFiles(milvus::storage::ChunkManager* remote_chunk_manager, + const uint8_t* data, + const int64_t data_size, + const std::string& remote_prefix, + std::unordered_map& map) { + remote_chunk_manager->Write( + remote_prefix, const_cast(data), data_size); + map[remote_prefix] = data_size; +} + +void +RemoveClusteringResultFiles( + milvus::storage::ChunkManager* remote_chunk_manager, + const std::unordered_map& map) { + auto& pool = ThreadPools::GetThreadPool(milvus::ThreadPoolPriority::MIDDLE); + std::vector> futures; + + for (auto& [file_path, file_size] : map) { + futures.push_back(pool.Submit( + [&, path = file_path]() { remote_chunk_manager->Remove(path); })); + } + std::exception_ptr first_exception = nullptr; + for (auto& future : futures) { + try { + future.get(); + } catch (...) 
{ + if (!first_exception) { + first_exception = std::current_exception(); + } + } + } + if (first_exception) { + std::rethrow_exception(first_exception); + } +} + +} // namespace milvus::clustering diff --git a/internal/core/src/clustering/milvus_clustering.pc.in b/internal/core/src/clustering/milvus_clustering.pc.in new file mode 100644 index 0000000000..d1bbb3d3ba --- /dev/null +++ b/internal/core/src/clustering/milvus_clustering.pc.in @@ -0,0 +1,9 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: Milvus Clustering +Description: Clustering modules for Milvus +Version: @MILVUS_VERSION@ + +Libs: -L${libdir} -lmilvus_clustering +Cflags: -I${includedir} diff --git a/internal/core/src/clustering/type_c.h b/internal/core/src/clustering/type_c.h new file mode 100644 index 0000000000..51d8d61665 --- /dev/null +++ b/internal/core/src/clustering/type_c.h @@ -0,0 +1,17 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include "common/type_c.h" + +typedef void* CAnalyze; +typedef void* CAnalyzeInfo; diff --git a/internal/core/src/clustering/types.h b/internal/core/src/clustering/types.h new file mode 100644 index 0000000000..57e1890861 --- /dev/null +++ b/internal/core/src/clustering/types.h @@ -0,0 +1,41 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
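// Sketch of the expected caller flow for the analyze C API declared in
// analyze_c.h (illustrative only; proto population follows the unit test's
// transforConfigToPB helper, the collection/partition/build ids, field schema
// and storage config are omitted, and the concrete values are assumptions):
//
//   milvus::proto::clustering::AnalyzeInfo info;
//   info.set_dim(dim);
//   info.set_num_clusters(8);
//   info.set_train_size(train_bytes);
//   info.set_min_cluster_ratio(0.01);
//   info.set_max_cluster_ratio(10.0);
//   info.set_max_cluster_size(5L * 1024 * 1024 * 1024);
//   (*info.mutable_num_rows())[segment_id] = nb;
//   (*info.mutable_insert_files())[segment_id].add_insert_files(binlog_path);
//
//   std::string blob = info.SerializeAsString();
//   CAnalyze handle = nullptr;
//   CStatus s = Analyze(&handle, reinterpret_cast<const uint8_t*>(blob.data()), blob.size());
//
//   // the caller pre-allocates one slot per segment for the id-mapping outputs
//   char* centroid_path = nullptr;
//   int64_t centroid_size = 0;
//   GetAnalyzeResultMeta(handle, &centroid_path, &centroid_size, id_mapping_paths, id_mapping_sizes);
//   DeleteAnalyze(handle);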
+ +#include +#include +#include +#include +#include "common/Types.h" +#include "index/Index.h" +#include "storage/Types.h" + +struct AnalyzeInfo { + int64_t collection_id; + int64_t partition_id; + int64_t field_id; + int64_t task_id; + int64_t version; + std::string field_name; + milvus::DataType field_type; + int64_t dim; + int64_t num_clusters; + int64_t train_size; + std::map> + insert_files; // segment_id->files + std::map num_rows; + milvus::storage::StorageConfig storage_config; + milvus::Config config; +}; diff --git a/internal/core/src/common/Consts.h b/internal/core/src/common/Consts.h index 44d7d5559c..e5ee8d6765 100644 --- a/internal/core/src/common/Consts.h +++ b/internal/core/src/common/Consts.h @@ -38,6 +38,11 @@ const char INDEX_BUILD_ID_KEY[] = "indexBuildID"; const char INDEX_ROOT_PATH[] = "index_files"; const char RAWDATA_ROOT_PATH[] = "raw_datas"; +const char ANALYZE_ROOT_PATH[] = "analyze_stats"; +const char CENTROIDS_NAME[] = "centroids"; +const char OFFSET_MAPPING_NAME[] = "offset_mapping"; +const char NUM_CLUSTERS[] = "num_clusters"; +const char KMEANS_CLUSTER[] = "KMEANS"; const char VEC_OPT_FIELDS[] = "opt_fields"; const char DEFAULT_PLANNODE_ID[] = "0"; diff --git a/internal/core/src/common/EasyAssert.h b/internal/core/src/common/EasyAssert.h index 0182e256e5..c23fd0f1c2 100644 --- a/internal/core/src/common/EasyAssert.h +++ b/internal/core/src/common/EasyAssert.h @@ -60,6 +60,7 @@ enum ErrorCode { UnistdError = 2030, MetricTypeNotMatch = 2031, DimNotMatch = 2032, + ClusterSkip = 2033, KnowhereError = 2100, }; diff --git a/internal/core/src/indexbuilder/types.h b/internal/core/src/indexbuilder/types.h index e1c656e25d..aed989ce59 100644 --- a/internal/core/src/indexbuilder/types.h +++ b/internal/core/src/indexbuilder/types.h @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include #include diff --git a/internal/core/src/pb/CMakeLists.txt b/internal/core/src/pb/CMakeLists.txt index 35726d9c24..d49637702d 100644 --- a/internal/core/src/pb/CMakeLists.txt +++ b/internal/core/src/pb/CMakeLists.txt @@ -15,7 +15,7 @@ file(GLOB_RECURSE milvus_proto_srcs "${CMAKE_CURRENT_SOURCE_DIR}/*.cc") add_library(milvus_proto STATIC ${milvus_proto_srcs} - ) +) message(STATUS "milvus proto sources: " ${milvus_proto_srcs}) target_link_libraries( milvus_proto PUBLIC ${CONAN_LIBS} ) diff --git a/internal/core/src/storage/DiskFileManagerImpl.cpp b/internal/core/src/storage/DiskFileManagerImpl.cpp index a97f1503f9..0328b7e8b2 100644 --- a/internal/core/src/storage/DiskFileManagerImpl.cpp +++ b/internal/core/src/storage/DiskFileManagerImpl.cpp @@ -442,6 +442,7 @@ SortByPath(std::vector& paths) { std::stol(b.substr(b.find_last_of("/") + 1)); }); } + template std::string DiskFileManagerImpl::CacheRawDataToDisk(std::vector remote_files) { diff --git a/internal/core/src/storage/FileManager.h b/internal/core/src/storage/FileManager.h index a0d94cfc58..816beb2e8a 100644 --- a/internal/core/src/storage/FileManager.h +++ b/internal/core/src/storage/FileManager.h @@ -130,6 +130,11 @@ class FileManagerImpl : public knowhere::FileManager { return index_meta_; } + virtual ChunkManagerPtr + GetChunkManager() const { + return rcm_; + } + virtual std::string GetRemoteIndexObjectPrefix() const { return rcm_->GetRootPath() + "/" + std::string(INDEX_ROOT_PATH) + "/" + diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 7abde651f3..39e242cec4 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -69,6 +69,12 @@ set(MILVUS_TEST_FILES test_regex_query.cpp ) +if ( INDEX_ENGINE STREQUAL "cardinal" ) + set(MILVUS_TEST_FILES + ${MILVUS_TEST_FILES} + test_kmeans_clustering.cpp) +endif() + if ( BUILD_DISK_ANN STREQUAL "ON" ) set(MILVUS_TEST_FILES ${MILVUS_TEST_FILES} @@ -121,6 +127,7 @@ if (LINUX) milvus_segcore milvus_storage milvus_indexbuilder + milvus_clustering milvus_common ) install(TARGETS index_builder_test DESTINATION unittest) @@ -135,6 +142,7 @@ target_link_libraries(all_tests milvus_segcore milvus_storage milvus_indexbuilder + milvus_clustering pthread milvus_common milvus_exec diff --git a/internal/core/unittest/test_kmeans_clustering.cpp b/internal/core/unittest/test_kmeans_clustering.cpp new file mode 100644 index 0000000000..5cf35ef65d --- /dev/null +++ b/internal/core/unittest/test_kmeans_clustering.cpp @@ -0,0 +1,321 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include +#include + +#include "common/Tracer.h" +#include "common/EasyAssert.h" +#include "index/InvertedIndexTantivy.h" +#include "storage/Util.h" +#include "storage/InsertData.h" +#include "clustering/KmeansClustering.h" +#include "storage/LocalChunkManagerSingleton.h" +#include "test_utils/indexbuilder_test_utils.h" +#include "test_utils/storage_test_utils.h" +#include "index/Meta.h" + +using namespace milvus; + +void +ReadPBFile(std::string& file_path, google::protobuf::Message& message) { + std::ifstream infile; + infile.open(file_path.data(), std::ios_base::binary); + if (infile.fail()) { + std::stringstream err_msg; + err_msg << "Error: open local file '" << file_path << " failed, " + << strerror(errno); + throw SegcoreError(FileOpenFailed, err_msg.str()); + } + + infile.seekg(0, std::ios::beg); + if (!message.ParseFromIstream(&infile)) { + std::stringstream err_msg; + err_msg << "Error: parse pb file '" << file_path << " failed, " + << strerror(errno); + throw SegcoreError(FileReadFailed, err_msg.str()); + } + infile.close(); +} + +milvus::proto::clustering::AnalyzeInfo +transforConfigToPB(const Config& config) { + milvus::proto::clustering::AnalyzeInfo analyze_info; + analyze_info.set_num_clusters(config["num_clusters"]); + analyze_info.set_max_cluster_ratio(config["max_cluster_ratio"]); + analyze_info.set_min_cluster_ratio(config["min_cluster_ratio"]); + analyze_info.set_max_cluster_size(config["max_cluster_size"]); + auto& num_rows = *analyze_info.mutable_num_rows(); + for (const auto& [k, v] : + milvus::index::GetValueFromConfig>( + config, "num_rows") + .value()) { + num_rows[k] = v; + } + auto& insert_files = *analyze_info.mutable_insert_files(); + auto insert_files_map = + milvus::index::GetValueFromConfig< + std::map>>(config, "insert_files") + .value(); + for (const auto& [k, v] : insert_files_map) { + for (auto i = 0; i < v.size(); i++) + insert_files[k].add_insert_files(v[i]); + } + analyze_info.set_dim(config["dim"]); + analyze_info.set_train_size(config["train_size"]); + return analyze_info; +} + +template +void +CheckResultCorrectness( + const milvus::clustering::KmeansClusteringPtr& clusteringJob, + int64_t segment_id, + int64_t segment_id2, + int64_t dim, + int64_t nb, + int expected_num_clusters, + bool check_centroids) { + std::string centroids_path_prefix = + clusteringJob->GetRemoteCentroidsObjectPrefix(); + std::string centroids_name = std::string(CENTROIDS_NAME); + std::string centroid_path = centroids_path_prefix + "/" + centroids_name; + milvus::proto::clustering::ClusteringCentroidsStats stats; + ReadPBFile(centroid_path, stats); + std::vector centroids; + for (const auto& centroid : stats.centroids()) { + const auto& float_vector = centroid.float_vector(); + for (float value : float_vector.data()) { + centroids.emplace_back(T(value)); + } + } + ASSERT_EQ(centroids.size(), expected_num_clusters * dim); + std::string offset_mapping_name = std::string(OFFSET_MAPPING_NAME); + std::string centroid_id_mapping_path = + clusteringJob->GetRemoteCentroidIdMappingObjectPrefix(segment_id) + + "/" + offset_mapping_name; + milvus::proto::clustering::ClusteringCentroidIdMappingStats mapping_stats; + std::string centroid_id_mapping_path2 = + clusteringJob->GetRemoteCentroidIdMappingObjectPrefix(segment_id2) + + "/" + offset_mapping_name; + milvus::proto::clustering::ClusteringCentroidIdMappingStats mapping_stats2; + 
ReadPBFile(centroid_id_mapping_path, mapping_stats); + ReadPBFile(centroid_id_mapping_path2, mapping_stats2); + + std::vector centroid_id_mapping; + std::vector num_in_centroid; + for (const auto id : mapping_stats.centroid_id_mapping()) { + centroid_id_mapping.emplace_back(id); + ASSERT_TRUE(id < expected_num_clusters); + } + ASSERT_EQ(centroid_id_mapping.size(), nb); + for (const auto num : mapping_stats.num_in_centroid()) { + num_in_centroid.emplace_back(num); + } + ASSERT_EQ( + std::accumulate(num_in_centroid.begin(), num_in_centroid.end(), 0), nb); + // second id mapping should be the same with the first one since the segment data is the same + if (check_centroids) { + for (int64_t i = 0; i < mapping_stats2.centroid_id_mapping_size(); + i++) { + ASSERT_EQ(mapping_stats2.centroid_id_mapping(i), + centroid_id_mapping[i]); + } + for (int64_t i = 0; i < mapping_stats2.num_in_centroid_size(); i++) { + ASSERT_EQ(mapping_stats2.num_in_centroid(i), num_in_centroid[i]); + } + } +} + +template +void +test_run() { + int64_t collection_id = 1; + int64_t partition_id = 2; + int64_t segment_id = 3; + int64_t segment_id2 = 4; + int64_t field_id = 101; + int64_t index_build_id = 1000; + int64_t index_version = 10000; + int64_t dim = 100; + int64_t nb = 10000; + + auto field_meta = + gen_field_meta(collection_id, partition_id, segment_id, field_id); + auto index_meta = + gen_index_meta(segment_id, field_id, index_build_id, index_version); + + std::string root_path = "/tmp/test-kmeans-clustering/"; + auto storage_config = gen_local_storage_config(root_path); + auto cm = storage::CreateChunkManager(storage_config); + + std::vector data_gen(nb * dim); + for (int64_t i = 0; i < nb * dim; ++i) { + data_gen[i] = rand(); + } + auto field_data = storage::CreateFieldData(dtype, dim); + field_data->FillFieldData(data_gen.data(), data_gen.size() / dim); + storage::InsertData insert_data(field_data); + insert_data.SetFieldDataMeta(field_meta); + insert_data.SetTimestamps(0, 100); + auto serialized_bytes = insert_data.Serialize(storage::Remote); + + auto get_binlog_path = [=](int64_t log_id) { + return fmt::format("{}/{}/{}/{}/{}", + collection_id, + partition_id, + segment_id, + field_id, + log_id); + }; + + auto log_path = get_binlog_path(0); + auto cm_w = ChunkManagerWrapper(cm); + cm_w.Write(log_path, serialized_bytes.data(), serialized_bytes.size()); + storage::FileManagerContext ctx(field_meta, index_meta, cm); + + std::map> remote_files; + std::map num_rows; + // two segments + remote_files[segment_id] = {log_path}; + remote_files[segment_id2] = {log_path}; + num_rows[segment_id] = nb; + num_rows[segment_id2] = nb; + Config config; + config["max_cluster_ratio"] = 10.0; + config["max_cluster_size"] = 5L * 1024 * 1024 * 1024; + + // no need to sample train data + { + config["min_cluster_ratio"] = 0.01; + config["insert_files"] = remote_files; + config["num_clusters"] = 8; + config["train_size"] = 25L * 1024 * 1024 * 1024; // 25GB + config["dim"] = dim; + config["num_rows"] = num_rows; + auto clusteringJob = + std::make_unique(ctx); + clusteringJob->Run(transforConfigToPB(config)); + CheckResultCorrectness(clusteringJob, + segment_id, + segment_id2, + dim, + nb, + config["num_clusters"], + true); + } + { + config["min_cluster_ratio"] = 0.01; + config["insert_files"] = remote_files; + config["num_clusters"] = 200; + config["train_size"] = 25L * 1024 * 1024 * 1024; // 25GB + config["dim"] = dim; + config["num_rows"] = num_rows; + auto clusteringJob = + std::make_unique(ctx); + 
+        clusteringJob->Run<T>(transforConfigToPB(config));
+        CheckResultCorrectness<T>(clusteringJob,
+                                  segment_id,
+                                  segment_id2,
+                                  dim,
+                                  nb,
+                                  config["num_clusters"],
+                                  true);
+    }
+    // num clusters larger than train num
+    {
+        EXPECT_THROW(
+            try {
+                config["min_cluster_ratio"] = 0.01;
+                config["insert_files"] = remote_files;
+                config["num_clusters"] = 100000;
+                config["train_size"] = 25L * 1024 * 1024 * 1024;  // 25GB
+                config["dim"] = dim;
+                config["num_rows"] = num_rows;
+                auto clusteringJob =
+                    std::make_unique<milvus::clustering::KmeansClustering>(ctx);
+                clusteringJob->Run<T>(transforConfigToPB(config));
+            } catch (SegcoreError& e) {
+                ASSERT_EQ(e.get_error_code(), ErrorCode::ClusterSkip);
+                throw e;
+            },
+            SegcoreError);
+    }
+
+    // data skew
+    {
+        EXPECT_THROW(
+            try {
+                config["min_cluster_ratio"] = 0.98;
+                config["insert_files"] = remote_files;
+                config["num_clusters"] = 8;
+                config["train_size"] = 25L * 1024 * 1024 * 1024;  // 25GB
+                config["dim"] = dim;
+                config["num_rows"] = num_rows;
+                auto clusteringJob =
+                    std::make_unique<milvus::clustering::KmeansClustering>(ctx);
+                clusteringJob->Run<T>(transforConfigToPB(config));
+            } catch (SegcoreError& e) {
+                ASSERT_EQ(e.get_error_code(), ErrorCode::ClusterSkip);
+                throw e;
+            },
+            SegcoreError);
+    }
+
+    // need to sample train data case1
+    {
+        config["min_cluster_ratio"] = 0.01;
+        config["insert_files"] = remote_files;
+        config["num_clusters"] = 8;
+        config["train_size"] = 1536L * 1024;  // 1.5MB
+        config["dim"] = dim;
+        config["num_rows"] = num_rows;
+        auto clusteringJob =
+            std::make_unique<milvus::clustering::KmeansClustering>(ctx);
+
+        clusteringJob->Run<T>(transforConfigToPB(config));
+        CheckResultCorrectness<T>(clusteringJob,
+                                  segment_id,
+                                  segment_id2,
+                                  dim,
+                                  nb,
+                                  config["num_clusters"],
+                                  true);
+    }
+    // need to sample train data case2
+    {
+        config["min_cluster_ratio"] = 0.01;
+        config["insert_files"] = remote_files;
+        config["num_clusters"] = 8;
+        config["train_size"] = 6L * 1024 * 1024;  // 6MB
+        config["dim"] = dim;
+        config["num_rows"] = num_rows;
+        auto clusteringJob =
+            std::make_unique<milvus::clustering::KmeansClustering>(ctx);
+
+        clusteringJob->Run<T>(transforConfigToPB(config));
+        CheckResultCorrectness<T>(clusteringJob,
+                                  segment_id,
+                                  segment_id2,
+                                  dim,
+                                  nb,
+                                  config["num_clusters"],
+                                  true);
+    }
+}
+
+TEST(MajorCompaction, Naive) {
+    test_run<float, DataType::VECTOR_FLOAT>();
+}
\ No newline at end of file
diff --git a/internal/core/unittest/test_utils/storage_test_utils.h b/internal/core/unittest/test_utils/storage_test_utils.h
index 7eca359f30..2cd38ef2db 100644
--- a/internal/core/unittest/test_utils/storage_test_utils.h
+++ b/internal/core/unittest/test_utils/storage_test_utils.h
@@ -23,6 +23,7 @@
 #include "storage/Types.h"
 #include "storage/InsertData.h"
 #include "storage/ThreadPools.h"
+#include <boost/filesystem.hpp>
 
 using milvus::DataType;
 using milvus::FieldDataPtr;
@@ -137,4 +138,61 @@ PutFieldData(milvus::storage::ChunkManager* remote_chunk_manager,
     return remote_paths_to_size;
 }
 
+auto
+gen_field_meta(int64_t collection_id = 1,
+               int64_t partition_id = 2,
+               int64_t segment_id = 3,
+               int64_t field_id = 101) -> milvus::storage::FieldDataMeta {
+    return milvus::storage::FieldDataMeta{
+        .collection_id = collection_id,
+        .partition_id = partition_id,
+        .segment_id = segment_id,
+        .field_id = field_id,
+    };
+}
+
+auto
+gen_index_meta(int64_t segment_id = 3,
+               int64_t field_id = 101,
+               int64_t index_build_id = 1000,
+               int64_t index_version = 10000) -> milvus::storage::IndexMeta {
+    return milvus::storage::IndexMeta{
+        .segment_id = segment_id,
+        .field_id = field_id,
+        .build_id = index_build_id,
+        .index_version = index_version,
+    };
+}
+
+auto
+gen_local_storage_config(const std::string& root_path)
+    -> milvus::storage::StorageConfig {
+    auto ret = milvus::storage::StorageConfig{};
+    ret.storage_type = "local";
+    ret.root_path = root_path;
+    return ret;
+}
+
+struct ChunkManagerWrapper {
+    ChunkManagerWrapper(milvus::storage::ChunkManagerPtr cm) : cm_(cm) {
+    }
+
+    ~ChunkManagerWrapper() {
+        for (const auto& file : written_) {
+            cm_->Remove(file);
+        }
+
+        boost::filesystem::remove_all(cm_->GetRootPath());
+    }
+
+    void
+    Write(const std::string& filepath, void* buf, uint64_t len) {
+        written_.insert(filepath);
+        cm_->Write(filepath, buf, len);
+    }
+
+    const milvus::storage::ChunkManagerPtr cm_;
+    std::unordered_set<std::string> written_;
+};
+
 }  // namespace
diff --git a/internal/datacoord/analyze_meta.go b/internal/datacoord/analyze_meta.go
new file mode 100644
index 0000000000..3e543e2b9c
--- /dev/null
+++ b/internal/datacoord/analyze_meta.go
@@ -0,0 +1,182 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package datacoord
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	"github.com/golang/protobuf/proto"
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus/internal/metastore"
+	"github.com/milvus-io/milvus/internal/proto/indexpb"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/util/timerecord"
+)
+
+type analyzeMeta struct {
+	sync.RWMutex
+
+	ctx     context.Context
+	catalog metastore.DataCoordCatalog
+
+	// taskID -> analyzeStats
+	// TODO: when to mark as dropped?
+ tasks map[int64]*indexpb.AnalyzeTask +} + +func newAnalyzeMeta(ctx context.Context, catalog metastore.DataCoordCatalog) (*analyzeMeta, error) { + mt := &analyzeMeta{ + ctx: ctx, + catalog: catalog, + tasks: make(map[int64]*indexpb.AnalyzeTask), + } + + if err := mt.reloadFromKV(); err != nil { + return nil, err + } + return mt, nil +} + +func (m *analyzeMeta) reloadFromKV() error { + record := timerecord.NewTimeRecorder("analyzeMeta-reloadFromKV") + + // load analyze stats + analyzeTasks, err := m.catalog.ListAnalyzeTasks(m.ctx) + if err != nil { + log.Warn("analyzeMeta reloadFromKV load analyze tasks failed", zap.Error(err)) + return err + } + + for _, analyzeTask := range analyzeTasks { + m.tasks[analyzeTask.TaskID] = analyzeTask + } + log.Info("analyzeMeta reloadFromKV done", zap.Duration("duration", record.ElapseSpan())) + return nil +} + +func (m *analyzeMeta) saveTask(newTask *indexpb.AnalyzeTask) error { + if err := m.catalog.SaveAnalyzeTask(m.ctx, newTask); err != nil { + return err + } + m.tasks[newTask.TaskID] = newTask + return nil +} + +func (m *analyzeMeta) GetTask(taskID int64) *indexpb.AnalyzeTask { + m.RLock() + defer m.RUnlock() + + return m.tasks[taskID] +} + +func (m *analyzeMeta) AddAnalyzeTask(task *indexpb.AnalyzeTask) error { + m.Lock() + defer m.Unlock() + + log.Info("add analyze task", zap.Int64("taskID", task.TaskID), + zap.Int64("collectionID", task.CollectionID), zap.Int64("partitionID", task.PartitionID)) + return m.saveTask(task) +} + +func (m *analyzeMeta) DropAnalyzeTask(taskID int64) error { + m.Lock() + defer m.Unlock() + + log.Info("drop analyze task", zap.Int64("taskID", taskID)) + if err := m.catalog.DropAnalyzeTask(m.ctx, taskID); err != nil { + log.Warn("drop analyze task by catalog failed", zap.Int64("taskID", taskID), + zap.Error(err)) + return err + } + + delete(m.tasks, taskID) + return nil +} + +func (m *analyzeMeta) UpdateVersion(taskID int64) error { + m.Lock() + defer m.Unlock() + + t, ok := m.tasks[taskID] + if !ok { + return fmt.Errorf("there is no task with taskID: %d", taskID) + } + + cloneT := proto.Clone(t).(*indexpb.AnalyzeTask) + cloneT.Version++ + log.Info("update task version", zap.Int64("taskID", taskID), zap.Int64("newVersion", cloneT.Version)) + return m.saveTask(cloneT) +} + +func (m *analyzeMeta) BuildingTask(taskID, nodeID int64) error { + m.Lock() + defer m.Unlock() + + t, ok := m.tasks[taskID] + if !ok { + return fmt.Errorf("there is no task with taskID: %d", taskID) + } + + cloneT := proto.Clone(t).(*indexpb.AnalyzeTask) + cloneT.NodeID = nodeID + cloneT.State = indexpb.JobState_JobStateInProgress + log.Info("task will be building", zap.Int64("taskID", taskID), zap.Int64("nodeID", nodeID)) + + return m.saveTask(cloneT) +} + +func (m *analyzeMeta) FinishTask(taskID int64, result *indexpb.AnalyzeResult) error { + m.Lock() + defer m.Unlock() + + t, ok := m.tasks[taskID] + if !ok { + return fmt.Errorf("there is no task with taskID: %d", taskID) + } + + log.Info("finish task meta...", zap.Int64("taskID", taskID), zap.String("state", result.GetState().String()), + zap.String("failReason", result.GetFailReason())) + + cloneT := proto.Clone(t).(*indexpb.AnalyzeTask) + cloneT.State = result.GetState() + cloneT.FailReason = result.GetFailReason() + cloneT.CentroidsFile = result.GetCentroidsFile() + return m.saveTask(cloneT) +} + +func (m *analyzeMeta) GetAllTasks() map[int64]*indexpb.AnalyzeTask { + m.RLock() + defer m.RUnlock() + + return m.tasks +} + +func (m *analyzeMeta) CheckCleanAnalyzeTask(taskID UniqueID) (bool, 
*indexpb.AnalyzeTask) { + m.RLock() + defer m.RUnlock() + + if t, ok := m.tasks[taskID]; ok { + if t.State == indexpb.JobState_JobStateFinished { + return true, t + } + return false, t + } + return true, nil +} diff --git a/internal/datacoord/analyze_meta_test.go b/internal/datacoord/analyze_meta_test.go new file mode 100644 index 0000000000..fdecb64796 --- /dev/null +++ b/internal/datacoord/analyze_meta_test.go @@ -0,0 +1,267 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datacoord + +import ( + "context" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/metastore/mocks" + "github.com/milvus-io/milvus/internal/proto/indexpb" +) + +type AnalyzeMetaSuite struct { + suite.Suite + + collectionID int64 + partitionID int64 + fieldID int64 + segmentIDs []int64 +} + +func (s *AnalyzeMetaSuite) initParams() { + s.collectionID = 100 + s.partitionID = 101 + s.fieldID = 102 + s.segmentIDs = []int64{1000, 1001, 1002, 1003} +} + +func (s *AnalyzeMetaSuite) Test_AnalyzeMeta() { + s.initParams() + + catalog := mocks.NewDataCoordCatalog(s.T()) + catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return([]*indexpb.AnalyzeTask{ + { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 1, + State: indexpb.JobState_JobStateNone, + }, + { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 2, + State: indexpb.JobState_JobStateInit, + }, + { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 3, + State: indexpb.JobState_JobStateInProgress, + }, + { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 4, + State: indexpb.JobState_JobStateRetry, + }, + { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 5, + State: indexpb.JobState_JobStateFinished, + }, + { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 6, + State: indexpb.JobState_JobStateFailed, + }, + }, nil) + + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil) + catalog.EXPECT().DropAnalyzeTask(mock.Anything, mock.Anything).Return(nil) + + ctx := context.Background() + + am, err := newAnalyzeMeta(ctx, catalog) + s.NoError(err) + s.Equal(6, len(am.GetAllTasks())) + + s.Run("GetTask", func() { + t := am.GetTask(1) + s.NotNil(t) + + t = am.GetTask(100) + s.Nil(t) + }) + + s.Run("AddAnalyzeTask", func() { + t := &indexpb.AnalyzeTask{ + CollectionID: 
s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 7, + } + + err := am.AddAnalyzeTask(t) + s.NoError(err) + s.Equal(7, len(am.GetAllTasks())) + + err = am.AddAnalyzeTask(t) + s.NoError(err) + s.Equal(7, len(am.GetAllTasks())) + }) + + s.Run("DropAnalyzeTask", func() { + err := am.DropAnalyzeTask(7) + s.NoError(err) + s.Equal(6, len(am.GetAllTasks())) + }) + + s.Run("UpdateVersion", func() { + err := am.UpdateVersion(1) + s.NoError(err) + s.Equal(int64(1), am.GetTask(1).Version) + }) + + s.Run("BuildingTask", func() { + err := am.BuildingTask(1, 1) + s.NoError(err) + s.Equal(indexpb.JobState_JobStateInProgress, am.GetTask(1).State) + }) + + s.Run("FinishTask", func() { + err := am.FinishTask(1, &indexpb.AnalyzeResult{ + TaskID: 1, + State: indexpb.JobState_JobStateFinished, + }) + s.NoError(err) + s.Equal(indexpb.JobState_JobStateFinished, am.GetTask(1).State) + }) +} + +func (s *AnalyzeMetaSuite) Test_failCase() { + s.initParams() + + catalog := mocks.NewDataCoordCatalog(s.T()) + catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, errors.New("error")).Once() + ctx := context.Background() + am, err := newAnalyzeMeta(ctx, catalog) + s.Error(err) + s.Nil(am) + + catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return([]*indexpb.AnalyzeTask{ + { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 1, + State: indexpb.JobState_JobStateInit, + }, + { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 2, + State: indexpb.JobState_JobStateFinished, + }, + }, nil) + am, err = newAnalyzeMeta(ctx, catalog) + s.NoError(err) + s.NotNil(am) + s.Equal(2, len(am.GetAllTasks())) + + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(errors.New("error")) + catalog.EXPECT().DropAnalyzeTask(mock.Anything, mock.Anything).Return(errors.New("error")) + s.Run("AddAnalyzeTask", func() { + t := &indexpb.AnalyzeTask{ + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 1111, + } + err := am.AddAnalyzeTask(t) + s.Error(err) + s.Nil(am.GetTask(1111)) + }) + + s.Run("DropAnalyzeTask", func() { + err := am.DropAnalyzeTask(1) + s.Error(err) + s.NotNil(am.GetTask(1)) + }) + + s.Run("UpdateVersion", func() { + err := am.UpdateVersion(777) + s.Error(err) + + err = am.UpdateVersion(1) + s.Error(err) + s.Equal(int64(0), am.GetTask(1).Version) + }) + + s.Run("BuildingTask", func() { + err := am.BuildingTask(777, 1) + s.Error(err) + + err = am.BuildingTask(1, 1) + s.Error(err) + s.Equal(int64(0), am.GetTask(1).NodeID) + s.Equal(indexpb.JobState_JobStateInit, am.GetTask(1).State) + }) + + s.Run("FinishTask", func() { + err := am.FinishTask(777, nil) + s.Error(err) + + err = am.FinishTask(1, &indexpb.AnalyzeResult{ + TaskID: 1, + State: indexpb.JobState_JobStateFinished, + }) + s.Error(err) + s.Equal(indexpb.JobState_JobStateInit, am.GetTask(1).State) + }) + + s.Run("CheckCleanAnalyzeTask", func() { + canRecycle, t := am.CheckCleanAnalyzeTask(1) + s.False(canRecycle) + s.Equal(indexpb.JobState_JobStateInit, t.GetState()) + + canRecycle, t = am.CheckCleanAnalyzeTask(777) + s.True(canRecycle) + s.Nil(t) + + canRecycle, t = am.CheckCleanAnalyzeTask(2) + s.True(canRecycle) + s.Equal(indexpb.JobState_JobStateFinished, t.GetState()) + }) +} + +func TestAnalyzeMeta(t *testing.T) { + suite.Run(t, new(AnalyzeMetaSuite)) +} diff --git 
a/internal/datacoord/garbage_collector.go b/internal/datacoord/garbage_collector.go index b62312d55d..4a75878682 100644 --- a/internal/datacoord/garbage_collector.go +++ b/internal/datacoord/garbage_collector.go @@ -620,7 +620,7 @@ func (gc *garbageCollector) recycleUnusedIndexFiles(ctx context.Context) { } logger = logger.With(zap.Int64("buildID", buildID)) logger.Info("garbageCollector will recycle index files") - canRecycle, segIdx := gc.meta.indexMeta.CleanSegmentIndex(buildID) + canRecycle, segIdx := gc.meta.indexMeta.CheckCleanSegmentIndex(buildID) if !canRecycle { // Even if the index is marked as deleted, the index file will not be recycled, wait for the next gc, // and delete all index files about the buildID at one time. diff --git a/internal/datacoord/import_checker_test.go b/internal/datacoord/import_checker_test.go index a19689ffc5..1cd1e10053 100644 --- a/internal/datacoord/import_checker_test.go +++ b/internal/datacoord/import_checker_test.go @@ -53,6 +53,7 @@ func (s *ImportCheckerSuite) SetupTest() { catalog.EXPECT().ListIndexes(mock.Anything).Return(nil, nil) catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return(nil, nil) catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil) + catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil) cluster := NewMockCluster(s.T()) alloc := NewNMockAllocator(s.T()) diff --git a/internal/datacoord/import_scheduler_test.go b/internal/datacoord/import_scheduler_test.go index eb15688bab..9bb542db66 100644 --- a/internal/datacoord/import_scheduler_test.go +++ b/internal/datacoord/import_scheduler_test.go @@ -57,6 +57,7 @@ func (s *ImportSchedulerSuite) SetupTest() { s.catalog.EXPECT().ListIndexes(mock.Anything).Return(nil, nil) s.catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return(nil, nil) s.catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil) + s.catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil) s.cluster = NewMockCluster(s.T()) s.alloc = NewNMockAllocator(s.T()) diff --git a/internal/datacoord/import_util_test.go b/internal/datacoord/import_util_test.go index 410b1c4963..415880ba3d 100644 --- a/internal/datacoord/import_util_test.go +++ b/internal/datacoord/import_util_test.go @@ -153,6 +153,7 @@ func TestImportUtil_AssembleRequest(t *testing.T) { catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return(nil, nil) catalog.EXPECT().AddSegment(mock.Anything, mock.Anything).Return(nil) catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil) + catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil) alloc := NewNMockAllocator(t) alloc.EXPECT().allocN(mock.Anything).RunAndReturn(func(n int64) (int64, int64, error) { @@ -233,6 +234,7 @@ func TestImportUtil_CheckDiskQuota(t *testing.T) { catalog.EXPECT().ListChannelCheckpoint(mock.Anything).Return(nil, nil) catalog.EXPECT().AddSegment(mock.Anything, mock.Anything).Return(nil) catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil) + catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil) imeta, err := NewImportMeta(catalog) assert.NoError(t, err) @@ -408,6 +410,7 @@ func TestImportUtil_GetImportProgress(t *testing.T) { catalog.EXPECT().AddSegment(mock.Anything, mock.Anything).Return(nil) catalog.EXPECT().AlterSegments(mock.Anything, mock.Anything).Return(nil) catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil) + catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil) imeta, err := NewImportMeta(catalog) assert.NoError(t, err) diff --git 
a/internal/datacoord/index_builder.go b/internal/datacoord/index_builder.go deleted file mode 100644 index 9a83f2384c..0000000000 --- a/internal/datacoord/index_builder.go +++ /dev/null @@ -1,575 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package datacoord - -import ( - "context" - "path" - "sync" - "time" - - "go.uber.org/zap" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/milvus-io/milvus/internal/querycoordv2/params" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/types" - itypeutil "github.com/milvus-io/milvus/internal/util/typeutil" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/indexparams" - "github.com/milvus-io/milvus/pkg/util/lock" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -type indexTaskState int32 - -const ( - // when we receive a index task - indexTaskInit indexTaskState = iota - // we've sent index task to scheduler, and wait for building index. - indexTaskInProgress - // task done, wait to be cleaned - indexTaskDone - // index task need to retry. 
- indexTaskRetry - - reqTimeoutInterval = time.Second * 10 -) - -var TaskStateNames = map[indexTaskState]string{ - 0: "Init", - 1: "InProgress", - 2: "Done", - 3: "Retry", -} - -func (x indexTaskState) String() string { - ret, ok := TaskStateNames[x] - if !ok { - return "None" - } - return ret -} - -type indexBuilder struct { - ctx context.Context - cancel context.CancelFunc - - wg sync.WaitGroup - taskMutex lock.RWMutex - scheduleDuration time.Duration - - // TODO @xiaocai2333: use priority queue - tasks map[int64]indexTaskState - notifyChan chan struct{} - - meta *meta - - policy buildIndexPolicy - nodeManager *IndexNodeManager - chunkManager storage.ChunkManager - indexEngineVersionManager IndexEngineVersionManager - handler Handler -} - -func newIndexBuilder( - ctx context.Context, - metaTable *meta, nodeManager *IndexNodeManager, - chunkManager storage.ChunkManager, - indexEngineVersionManager IndexEngineVersionManager, - handler Handler, -) *indexBuilder { - ctx, cancel := context.WithCancel(ctx) - - ib := &indexBuilder{ - ctx: ctx, - cancel: cancel, - meta: metaTable, - tasks: make(map[int64]indexTaskState), - notifyChan: make(chan struct{}, 1), - scheduleDuration: Params.DataCoordCfg.IndexTaskSchedulerInterval.GetAsDuration(time.Millisecond), - policy: defaultBuildIndexPolicy, - nodeManager: nodeManager, - chunkManager: chunkManager, - handler: handler, - indexEngineVersionManager: indexEngineVersionManager, - } - ib.reloadFromKV() - return ib -} - -func (ib *indexBuilder) Start() { - ib.wg.Add(1) - go ib.schedule() -} - -func (ib *indexBuilder) Stop() { - ib.cancel() - ib.wg.Wait() -} - -func (ib *indexBuilder) reloadFromKV() { - segments := ib.meta.GetAllSegmentsUnsafe() - for _, segment := range segments { - for _, segIndex := range ib.meta.indexMeta.getSegmentIndexes(segment.ID) { - if segIndex.IsDeleted { - continue - } - if segIndex.IndexState == commonpb.IndexState_Unissued { - ib.tasks[segIndex.BuildID] = indexTaskInit - } else if segIndex.IndexState == commonpb.IndexState_InProgress { - ib.tasks[segIndex.BuildID] = indexTaskInProgress - } - } - } -} - -// notify is an unblocked notify function -func (ib *indexBuilder) notify() { - select { - case ib.notifyChan <- struct{}{}: - default: - } -} - -func (ib *indexBuilder) enqueue(buildID UniqueID) { - defer ib.notify() - - ib.taskMutex.Lock() - defer ib.taskMutex.Unlock() - if _, ok := ib.tasks[buildID]; !ok { - ib.tasks[buildID] = indexTaskInit - } - log.Info("indexBuilder enqueue task", zap.Int64("buildID", buildID)) -} - -func (ib *indexBuilder) schedule() { - // receive notifyChan - // time ticker - log.Ctx(ib.ctx).Info("index builder schedule loop start") - defer ib.wg.Done() - ticker := time.NewTicker(ib.scheduleDuration) - defer ticker.Stop() - for { - select { - case <-ib.ctx.Done(): - log.Ctx(ib.ctx).Warn("index builder ctx done") - return - case _, ok := <-ib.notifyChan: - if ok { - ib.run() - } - // !ok means indexBuild is closed. 
- case <-ticker.C: - ib.run() - } - } -} - -func (ib *indexBuilder) run() { - ib.taskMutex.RLock() - buildIDs := make([]UniqueID, 0, len(ib.tasks)) - for tID := range ib.tasks { - buildIDs = append(buildIDs, tID) - } - ib.taskMutex.RUnlock() - if len(buildIDs) > 0 { - log.Ctx(ib.ctx).Info("index builder task schedule", zap.Int("task num", len(buildIDs))) - } - - ib.policy(buildIDs) - - for _, buildID := range buildIDs { - ok := ib.process(buildID) - if !ok { - log.Ctx(ib.ctx).Info("there is no idle indexing node, wait a minute...") - break - } - } -} - -func getBinLogIDs(segment *SegmentInfo, fieldID int64) []int64 { - binlogIDs := make([]int64, 0) - for _, fieldBinLog := range segment.GetBinlogs() { - if fieldBinLog.GetFieldID() == fieldID { - for _, binLog := range fieldBinLog.GetBinlogs() { - binlogIDs = append(binlogIDs, binLog.GetLogID()) - } - break - } - } - return binlogIDs -} - -func (ib *indexBuilder) process(buildID UniqueID) bool { - ib.taskMutex.RLock() - state := ib.tasks[buildID] - ib.taskMutex.RUnlock() - - updateStateFunc := func(buildID UniqueID, state indexTaskState) { - ib.taskMutex.Lock() - defer ib.taskMutex.Unlock() - ib.tasks[buildID] = state - } - - deleteFunc := func(buildID UniqueID) { - ib.taskMutex.Lock() - defer ib.taskMutex.Unlock() - delete(ib.tasks, buildID) - } - - meta, exist := ib.meta.indexMeta.GetIndexJob(buildID) - if !exist { - log.Ctx(ib.ctx).Debug("index task has not exist in meta table, remove task", zap.Int64("buildID", buildID)) - deleteFunc(buildID) - return true - } - - switch state { - case indexTaskInit: - segment := ib.meta.GetSegment(meta.SegmentID) - if !isSegmentHealthy(segment) || !ib.meta.indexMeta.IsIndexExist(meta.CollectionID, meta.IndexID) { - log.Ctx(ib.ctx).Info("task is no need to build index, remove it", zap.Int64("buildID", buildID)) - if err := ib.meta.indexMeta.DeleteTask(buildID); err != nil { - log.Ctx(ib.ctx).Warn("IndexCoord delete index failed", zap.Int64("buildID", buildID), zap.Error(err)) - return false - } - deleteFunc(buildID) - return true - } - indexParams := ib.meta.indexMeta.GetIndexParams(meta.CollectionID, meta.IndexID) - indexType := GetIndexType(indexParams) - if isFlatIndex(indexType) || meta.NumRows < Params.DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64() { - log.Ctx(ib.ctx).Info("segment does not need index really", zap.Int64("buildID", buildID), - zap.Int64("segmentID", meta.SegmentID), zap.Int64("num rows", meta.NumRows)) - if err := ib.meta.indexMeta.FinishTask(&indexpb.IndexTaskInfo{ - BuildID: buildID, - State: commonpb.IndexState_Finished, - IndexFileKeys: nil, - SerializedSize: 0, - FailReason: "", - }); err != nil { - log.Ctx(ib.ctx).Warn("IndexCoord update index state fail", zap.Int64("buildID", buildID), zap.Error(err)) - return false - } - updateStateFunc(buildID, indexTaskDone) - return true - } - // peek client - // if all IndexNodes are executing task, wait for one of them to finish the task. 
- nodeID, client := ib.nodeManager.PeekClient(meta) - if client == nil { - log.Ctx(ib.ctx).WithRateGroup("dc.indexBuilder", 1, 60).RatedInfo(5, "index builder peek client error, there is no available") - return false - } - // update version and set nodeID - if err := ib.meta.indexMeta.UpdateVersion(buildID, nodeID); err != nil { - log.Ctx(ib.ctx).Warn("index builder update index version failed", zap.Int64("build", buildID), zap.Error(err)) - return false - } - - // vector index build needs information of optional scalar fields data - optionalFields := make([]*indexpb.OptionalFieldInfo, 0) - if Params.CommonCfg.EnableMaterializedView.GetAsBool() { - colSchema := ib.meta.GetCollection(meta.CollectionID).Schema - if colSchema != nil { - hasPartitionKey := typeutil.HasPartitionKey(colSchema) - if hasPartitionKey { - partitionKeyField, err := typeutil.GetPartitionKeyFieldSchema(colSchema) - if partitionKeyField == nil || err != nil { - log.Ctx(ib.ctx).Warn("index builder get partition key field failed", zap.Int64("build", buildID), zap.Error(err)) - } else { - if typeutil.IsFieldDataTypeSupportMaterializedView(partitionKeyField) { - optionalFields = append(optionalFields, &indexpb.OptionalFieldInfo{ - FieldID: partitionKeyField.FieldID, - FieldName: partitionKeyField.Name, - FieldType: int32(partitionKeyField.DataType), - DataIds: getBinLogIDs(segment, partitionKeyField.FieldID), - }) - } - } - } - } - } - - typeParams := ib.meta.indexMeta.GetTypeParams(meta.CollectionID, meta.IndexID) - - var storageConfig *indexpb.StorageConfig - if Params.CommonCfg.StorageType.GetValue() == "local" { - storageConfig = &indexpb.StorageConfig{ - RootPath: Params.LocalStorageCfg.Path.GetValue(), - StorageType: Params.CommonCfg.StorageType.GetValue(), - } - } else { - storageConfig = &indexpb.StorageConfig{ - Address: Params.MinioCfg.Address.GetValue(), - AccessKeyID: Params.MinioCfg.AccessKeyID.GetValue(), - SecretAccessKey: Params.MinioCfg.SecretAccessKey.GetValue(), - UseSSL: Params.MinioCfg.UseSSL.GetAsBool(), - SslCACert: Params.MinioCfg.SslCACert.GetValue(), - BucketName: Params.MinioCfg.BucketName.GetValue(), - RootPath: Params.MinioCfg.RootPath.GetValue(), - UseIAM: Params.MinioCfg.UseIAM.GetAsBool(), - IAMEndpoint: Params.MinioCfg.IAMEndpoint.GetValue(), - StorageType: Params.CommonCfg.StorageType.GetValue(), - Region: Params.MinioCfg.Region.GetValue(), - UseVirtualHost: Params.MinioCfg.UseVirtualHost.GetAsBool(), - CloudProvider: Params.MinioCfg.CloudProvider.GetValue(), - RequestTimeoutMs: Params.MinioCfg.RequestTimeoutMs.GetAsInt64(), - } - } - - fieldID := ib.meta.indexMeta.GetFieldIDByIndexID(meta.CollectionID, meta.IndexID) - binlogIDs := getBinLogIDs(segment, fieldID) - if isDiskANNIndex(GetIndexType(indexParams)) { - var err error - indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams) - if err != nil { - log.Ctx(ib.ctx).Warn("failed to append index build params", zap.Int64("buildID", buildID), - zap.Int64("nodeID", nodeID), zap.Error(err)) - } - } - var req *indexpb.CreateJobRequest - collectionInfo, err := ib.handler.GetCollection(ib.ctx, segment.GetCollectionID()) - if err != nil { - log.Ctx(ib.ctx).Info("index builder get collection info failed", zap.Int64("collectionID", segment.GetCollectionID()), zap.Error(err)) - return false - } - - schema := collectionInfo.Schema - var field *schemapb.FieldSchema - - for _, f := range schema.Fields { - if f.FieldID == fieldID { - field = f - break - } - } - - dim, err := storage.GetDimFromParams(field.TypeParams) - if err != 
nil { - log.Ctx(ib.ctx).Warn("failed to get dim from field type params", - zap.String("field type", field.GetDataType().String()), zap.Error(err)) - // don't return, maybe field is scalar field or sparseFloatVector - } - if Params.CommonCfg.EnableStorageV2.GetAsBool() { - storePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue(), segment.GetID()) - if err != nil { - log.Ctx(ib.ctx).Warn("failed to get storage uri", zap.Error(err)) - return false - } - indexStorePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue()+"/index", segment.GetID()) - if err != nil { - log.Ctx(ib.ctx).Warn("failed to get storage uri", zap.Error(err)) - return false - } - - req = &indexpb.CreateJobRequest{ - ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), - IndexFilePrefix: path.Join(ib.chunkManager.RootPath(), common.SegmentIndexPath), - BuildID: buildID, - IndexVersion: meta.IndexVersion + 1, - StorageConfig: storageConfig, - IndexParams: indexParams, - TypeParams: typeParams, - NumRows: meta.NumRows, - CollectionID: segment.GetCollectionID(), - PartitionID: segment.GetPartitionID(), - SegmentID: segment.GetID(), - FieldID: fieldID, - FieldName: field.Name, - FieldType: field.DataType, - StorePath: storePath, - StoreVersion: segment.GetStorageVersion(), - IndexStorePath: indexStorePath, - Dim: int64(dim), - CurrentIndexVersion: ib.indexEngineVersionManager.GetCurrentIndexEngineVersion(), - DataIds: binlogIDs, - OptionalScalarFields: optionalFields, - Field: field, - } - } else { - req = &indexpb.CreateJobRequest{ - ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), - IndexFilePrefix: path.Join(ib.chunkManager.RootPath(), common.SegmentIndexPath), - BuildID: buildID, - IndexVersion: meta.IndexVersion + 1, - StorageConfig: storageConfig, - IndexParams: indexParams, - TypeParams: typeParams, - NumRows: meta.NumRows, - CurrentIndexVersion: ib.indexEngineVersionManager.GetCurrentIndexEngineVersion(), - DataIds: binlogIDs, - CollectionID: segment.GetCollectionID(), - PartitionID: segment.GetPartitionID(), - SegmentID: segment.GetID(), - FieldID: fieldID, - OptionalScalarFields: optionalFields, - Dim: int64(dim), - Field: field, - } - } - - if err := ib.assignTask(client, req); err != nil { - // need to release lock then reassign, so set task state to retry - log.Ctx(ib.ctx).Warn("index builder assign task to IndexNode failed", zap.Int64("buildID", buildID), - zap.Int64("nodeID", nodeID), zap.Error(err)) - updateStateFunc(buildID, indexTaskRetry) - return false - } - log.Ctx(ib.ctx).Info("index task assigned successfully", zap.Int64("buildID", buildID), - zap.Int64("segmentID", meta.SegmentID), zap.Int64("nodeID", nodeID)) - // update index meta state to InProgress - if err := ib.meta.indexMeta.BuildIndex(buildID); err != nil { - // need to release lock then reassign, so set task state to retry - log.Ctx(ib.ctx).Warn("index builder update index meta to InProgress failed", zap.Int64("buildID", buildID), - zap.Int64("nodeID", nodeID), zap.Error(err)) - updateStateFunc(buildID, indexTaskRetry) - return false - } - updateStateFunc(buildID, indexTaskInProgress) - - case indexTaskDone: - if !ib.dropIndexTask(buildID, meta.NodeID) { - return true - } - deleteFunc(buildID) - case indexTaskRetry: - if !ib.dropIndexTask(buildID, meta.NodeID) { - return true - } - updateStateFunc(buildID, indexTaskInit) - - default: - // state: in_progress - 
updateStateFunc(buildID, ib.getTaskState(buildID, meta.NodeID)) - } - return true -} - -func (ib *indexBuilder) getTaskState(buildID, nodeID UniqueID) indexTaskState { - client, exist := ib.nodeManager.GetClientByID(nodeID) - if exist { - ctx1, cancel := context.WithTimeout(ib.ctx, reqTimeoutInterval) - defer cancel() - response, err := client.QueryJobs(ctx1, &indexpb.QueryJobsRequest{ - ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), - BuildIDs: []int64{buildID}, - }) - if err != nil { - log.Ctx(ib.ctx).Warn("IndexCoord get jobs info from IndexNode fail", zap.Int64("nodeID", nodeID), - zap.Error(err)) - return indexTaskRetry - } - if response.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - log.Ctx(ib.ctx).Warn("IndexCoord get jobs info from IndexNode fail", zap.Int64("nodeID", nodeID), - zap.Int64("buildID", buildID), zap.String("fail reason", response.GetStatus().GetReason())) - return indexTaskRetry - } - - // indexInfos length is always one. - for _, info := range response.GetIndexInfos() { - if info.GetBuildID() == buildID { - if info.GetState() == commonpb.IndexState_Failed || info.GetState() == commonpb.IndexState_Finished { - log.Ctx(ib.ctx).Info("this task has been finished", zap.Int64("buildID", info.GetBuildID()), - zap.String("index state", info.GetState().String())) - if err := ib.meta.indexMeta.FinishTask(info); err != nil { - log.Ctx(ib.ctx).Warn("IndexCoord update index state fail", zap.Int64("buildID", info.GetBuildID()), - zap.String("index state", info.GetState().String()), zap.Error(err)) - return indexTaskInProgress - } - return indexTaskDone - } else if info.GetState() == commonpb.IndexState_Retry || info.GetState() == commonpb.IndexState_IndexStateNone { - log.Ctx(ib.ctx).Info("this task should be retry", zap.Int64("buildID", buildID), zap.String("fail reason", info.GetFailReason())) - return indexTaskRetry - } - return indexTaskInProgress - } - } - log.Ctx(ib.ctx).Info("this task should be retry, indexNode does not have this task", zap.Int64("buildID", buildID), - zap.Int64("nodeID", nodeID)) - return indexTaskRetry - } - // !exist --> node down - log.Ctx(ib.ctx).Info("this task should be retry, indexNode is no longer exist", zap.Int64("buildID", buildID), - zap.Int64("nodeID", nodeID)) - return indexTaskRetry -} - -func (ib *indexBuilder) dropIndexTask(buildID, nodeID UniqueID) bool { - client, exist := ib.nodeManager.GetClientByID(nodeID) - if exist { - ctx1, cancel := context.WithTimeout(ib.ctx, reqTimeoutInterval) - defer cancel() - status, err := client.DropJobs(ctx1, &indexpb.DropJobsRequest{ - ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), - BuildIDs: []UniqueID{buildID}, - }) - if err != nil { - log.Ctx(ib.ctx).Warn("IndexCoord notify IndexNode drop the index task fail", zap.Int64("buildID", buildID), - zap.Int64("nodeID", nodeID), zap.Error(err)) - return false - } - if status.GetErrorCode() != commonpb.ErrorCode_Success { - log.Ctx(ib.ctx).Warn("IndexCoord notify IndexNode drop the index task fail", zap.Int64("buildID", buildID), - zap.Int64("nodeID", nodeID), zap.String("fail reason", status.GetReason())) - return false - } - log.Ctx(ib.ctx).Info("IndexCoord notify IndexNode drop the index task success", - zap.Int64("buildID", buildID), zap.Int64("nodeID", nodeID)) - return true - } - log.Ctx(ib.ctx).Info("IndexNode no longer exist, no need to drop index task", - zap.Int64("buildID", buildID), zap.Int64("nodeID", nodeID)) - return true -} - -// assignTask sends the index task to the IndexNode, it has a timeout interval, if the 
IndexNode doesn't respond within -// the interval, it is considered that the task sending failed. -func (ib *indexBuilder) assignTask(builderClient types.IndexNodeClient, req *indexpb.CreateJobRequest) error { - ctx, cancel := context.WithTimeout(context.Background(), reqTimeoutInterval) - defer cancel() - resp, err := builderClient.CreateJob(ctx, req) - if err == nil { - err = merr.Error(resp) - } - if err != nil { - log.Error("IndexCoord assignmentTasksLoop builderClient.CreateIndex failed", zap.Error(err)) - return err - } - - return nil -} - -func (ib *indexBuilder) nodeDown(nodeID UniqueID) { - defer ib.notify() - - metas := ib.meta.indexMeta.GetMetasByNodeID(nodeID) - - ib.taskMutex.Lock() - defer ib.taskMutex.Unlock() - - for _, meta := range metas { - if ib.tasks[meta.BuildID] != indexTaskDone { - ib.tasks[meta.BuildID] = indexTaskRetry - } - } -} diff --git a/internal/datacoord/index_builder_test.go b/internal/datacoord/index_builder_test.go deleted file mode 100644 index 9488c70f5e..0000000000 --- a/internal/datacoord/index_builder_test.go +++ /dev/null @@ -1,1611 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package datacoord - -import ( - "context" - "testing" - "time" - - "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "google.golang.org/grpc" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/metastore" - catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks" - "github.com/milvus-io/milvus/internal/metastore/model" - "github.com/milvus-io/milvus/internal/mocks" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/milvus-io/milvus/internal/types" - mclient "github.com/milvus-io/milvus/internal/util/mock" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/paramtable" -) - -var ( - collID = UniqueID(100) - partID = UniqueID(200) - indexID = UniqueID(300) - fieldID = UniqueID(400) - indexName = "_default_idx" - segID = UniqueID(500) - buildID = UniqueID(600) - nodeID = UniqueID(700) -) - -func createMetaTable(catalog metastore.DataCoordCatalog) *meta { - segIndeMeta := &indexMeta{ - catalog: catalog, - indexes: map[UniqueID]map[UniqueID]*model.Index{ - collID: { - indexID: { - TenantID: "", - CollectionID: collID, - FieldID: fieldID, - IndexID: indexID, - IndexName: indexName, - IsDeleted: false, - CreateTime: 1, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: "128", - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: common.MetricTypeKey, - Value: "L2", - }, - }, - }, - }, - }, - segmentIndexes: map[UniqueID]map[UniqueID]*model.SegmentIndex{ - segID: { - indexID: { - SegmentID: segID, - CollectionID: collID, - PartitionID: partID, - NumRows: 1025, - IndexID: indexID, - BuildID: buildID, - NodeID: 0, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 0, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 1: { - indexID: { - SegmentID: segID + 1, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 1, - NodeID: nodeID, - IndexVersion: 1, - IndexState: commonpb.IndexState_InProgress, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 2: { - indexID: { - SegmentID: segID + 2, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 2, - NodeID: nodeID, - IndexVersion: 1, - IndexState: commonpb.IndexState_InProgress, - FailReason: "", - IsDeleted: true, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 3: { - indexID: { - SegmentID: segID + 3, - CollectionID: collID, - PartitionID: partID, - NumRows: 500, - IndexID: indexID, - BuildID: buildID + 3, - NodeID: 0, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 4: { - indexID: { - SegmentID: segID + 4, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 4, - NodeID: nodeID, - IndexVersion: 1, - IndexState: commonpb.IndexState_Finished, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 5: { - indexID: { - SegmentID: segID + 5, - 
CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 5, - NodeID: 0, - IndexVersion: 1, - IndexState: commonpb.IndexState_Finished, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 6: { - indexID: { - SegmentID: segID + 6, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 6, - NodeID: 0, - IndexVersion: 1, - IndexState: commonpb.IndexState_Finished, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 7: { - indexID: { - SegmentID: segID + 7, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 7, - NodeID: 0, - IndexVersion: 1, - IndexState: commonpb.IndexState_Failed, - FailReason: "error", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 8: { - indexID: { - SegmentID: segID + 8, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 8, - NodeID: nodeID + 1, - IndexVersion: 1, - IndexState: commonpb.IndexState_InProgress, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 9: { - indexID: { - SegmentID: segID + 9, - CollectionID: collID, - PartitionID: partID, - NumRows: 500, - IndexID: indexID, - BuildID: buildID + 9, - NodeID: 0, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - segID + 10: { - indexID: { - SegmentID: segID + 10, - CollectionID: collID, - PartitionID: partID, - NumRows: 500, - IndexID: indexID, - BuildID: buildID + 10, - NodeID: nodeID, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - }, - buildID2SegmentIndex: map[UniqueID]*model.SegmentIndex{ - buildID: { - SegmentID: segID, - CollectionID: collID, - PartitionID: partID, - NumRows: 1025, - IndexID: indexID, - BuildID: buildID, - NodeID: 0, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 0, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 1: { - SegmentID: segID + 1, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 1, - NodeID: nodeID, - IndexVersion: 1, - IndexState: commonpb.IndexState_InProgress, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 2: { - SegmentID: segID + 2, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 2, - NodeID: nodeID, - IndexVersion: 1, - IndexState: commonpb.IndexState_InProgress, - FailReason: "", - IsDeleted: true, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 3: { - SegmentID: segID + 3, - CollectionID: collID, - PartitionID: partID, - NumRows: 500, - IndexID: indexID, - BuildID: buildID + 3, - NodeID: 0, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 4: { - SegmentID: segID + 4, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 4, - NodeID: nodeID, - 
IndexVersion: 1, - IndexState: commonpb.IndexState_Finished, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 5: { - SegmentID: segID + 5, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 5, - NodeID: 0, - IndexVersion: 1, - IndexState: commonpb.IndexState_Finished, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 6: { - SegmentID: segID + 6, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 6, - NodeID: 0, - IndexVersion: 1, - IndexState: commonpb.IndexState_Finished, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 7: { - SegmentID: segID + 7, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 7, - NodeID: 0, - IndexVersion: 1, - IndexState: commonpb.IndexState_Failed, - FailReason: "error", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 8: { - SegmentID: segID + 8, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 8, - NodeID: nodeID + 1, - IndexVersion: 1, - IndexState: commonpb.IndexState_InProgress, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 9: { - SegmentID: segID + 9, - CollectionID: collID, - PartitionID: partID, - NumRows: 500, - IndexID: indexID, - BuildID: buildID + 9, - NodeID: 0, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - buildID + 10: { - SegmentID: segID + 10, - CollectionID: collID, - PartitionID: partID, - NumRows: 500, - IndexID: indexID, - BuildID: buildID + 10, - NodeID: nodeID, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 1111, - IndexFileKeys: nil, - IndexSize: 1, - }, - }, - } - - return &meta{ - indexMeta: segIndeMeta, - catalog: catalog, - segments: &SegmentsInfo{ - segments: map[UniqueID]*SegmentInfo{ - segID: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 1025, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 1: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID + 1, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 1026, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 2: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID + 2, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 1026, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 3: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID + 3, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 500, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 4: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID + 4, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 1026, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 5: { - SegmentInfo: &datapb.SegmentInfo{ 
- ID: segID + 5, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 1026, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 6: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID + 6, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 1026, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 7: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID + 7, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 1026, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 8: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID + 8, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 1026, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 9: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID + 9, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 500, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - segID + 10: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID + 10, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: 500, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - }, - }, - } -} - -func TestIndexBuilder(t *testing.T) { - var ( - collID = UniqueID(100) - partID = UniqueID(200) - indexID = UniqueID(300) - segID = UniqueID(500) - buildID = UniqueID(600) - nodeID = UniqueID(700) - ) - - paramtable.Init() - ctx := context.Background() - catalog := catalogmocks.NewDataCoordCatalog(t) - catalog.On("CreateSegmentIndex", - mock.Anything, - mock.Anything, - ).Return(nil) - catalog.On("AlterSegmentIndexes", - mock.Anything, - mock.Anything, - ).Return(nil) - - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything). - Return(&indexpb.GetJobStatsResponse{ - Status: merr.Success(), - TotalJobNum: 1, - EnqueueJobNum: 0, - InProgressJobNum: 1, - TaskSlots: 1, - JobInfos: []*indexpb.JobInfo{ - { - NumRows: 1024, - Dim: 128, - StartTime: 1, - EndTime: 10, - PodID: 1, - }, - }, - }, nil) - ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( - func(ctx context.Context, in *indexpb.QueryJobsRequest, option ...grpc.CallOption) (*indexpb.QueryJobsResponse, error) { - indexInfos := make([]*indexpb.IndexTaskInfo, 0) - for _, buildID := range in.BuildIDs { - indexInfos = append(indexInfos, &indexpb.IndexTaskInfo{ - BuildID: buildID, - State: commonpb.IndexState_Finished, - IndexFileKeys: []string{"file1", "file2"}, - }) - } - return &indexpb.QueryJobsResponse{ - Status: merr.Success(), - ClusterID: in.ClusterID, - IndexInfos: indexInfos, - }, nil - }) - - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(merr.Success(), nil) - - ic.EXPECT().DropJobs(mock.Anything, mock.Anything, mock.Anything). 
- Return(merr.Success(), nil) - mt := createMetaTable(catalog) - nodeManager := &IndexNodeManager{ - ctx: ctx, - nodeClients: map[UniqueID]types.IndexNodeClient{ - 4: ic, - }, - } - chunkManager := &mocks.ChunkManager{} - chunkManager.EXPECT().RootPath().Return("root") - - handler := NewNMockHandler(t) - handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ - ID: collID, - Schema: &schemapb.CollectionSchema{ - Name: "coll", - Fields: []*schemapb.FieldSchema{ - { - FieldID: fieldID, - Name: "vec", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "128", - }, - }, - }, - }, - EnableDynamicField: false, - Properties: nil, - }, - }, nil) - - ib := newIndexBuilder(ctx, mt, nodeManager, chunkManager, newIndexEngineVersionManager(), handler) - - assert.Equal(t, 6, len(ib.tasks)) - assert.Equal(t, indexTaskInit, ib.tasks[buildID]) - assert.Equal(t, indexTaskInProgress, ib.tasks[buildID+1]) - // buildID+2 will be filter by isDeleted - assert.Equal(t, indexTaskInit, ib.tasks[buildID+3]) - assert.Equal(t, indexTaskInProgress, ib.tasks[buildID+8]) - assert.Equal(t, indexTaskInit, ib.tasks[buildID+9]) - assert.Equal(t, indexTaskInit, ib.tasks[buildID+10]) - - ib.scheduleDuration = time.Millisecond * 500 - ib.Start() - - t.Run("enqueue", func(t *testing.T) { - segIdx := &model.SegmentIndex{ - SegmentID: segID + 10, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 10, - NodeID: 0, - IndexVersion: 0, - IndexState: 0, - FailReason: "", - IsDeleted: false, - CreateTime: 0, - IndexFileKeys: nil, - IndexSize: 0, - } - err := ib.meta.indexMeta.AddSegmentIndex(segIdx) - assert.NoError(t, err) - ib.enqueue(buildID + 10) - }) - - t.Run("node down", func(t *testing.T) { - ib.nodeDown(nodeID) - }) - - for { - ib.taskMutex.RLock() - if len(ib.tasks) == 0 { - break - } - ib.taskMutex.RUnlock() - } - ib.Stop() -} - -func TestIndexBuilder_Error(t *testing.T) { - paramtable.Init() - - sc := catalogmocks.NewDataCoordCatalog(t) - sc.On("AlterSegmentIndexes", - mock.Anything, - mock.Anything, - ).Return(nil) - ec := catalogmocks.NewDataCoordCatalog(t) - ec.On("AlterSegmentIndexes", - mock.Anything, - mock.Anything, - ).Return(errors.New("fail")) - - chunkManager := &mocks.ChunkManager{} - chunkManager.EXPECT().RootPath().Return("root") - - handler := NewNMockHandler(t) - handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ - ID: collID, - Schema: &schemapb.CollectionSchema{ - Name: "coll", - Fields: []*schemapb.FieldSchema{ - { - FieldID: fieldID, - Name: "vec", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "128", - }, - }, - }, - }, - EnableDynamicField: false, - Properties: nil, - }, - }, nil) - - ib := &indexBuilder{ - ctx: context.Background(), - tasks: map[int64]indexTaskState{ - buildID: indexTaskInit, - }, - meta: createMetaTable(ec), - chunkManager: chunkManager, - indexEngineVersionManager: newIndexEngineVersionManager(), - handler: handler, - } - - t.Run("meta not exist", func(t *testing.T) { - ib.tasks[buildID+100] = indexTaskInit - ib.process(buildID + 100) - - _, ok := ib.tasks[buildID+100] - assert.False(t, ok) - }) - - t.Run("finish few rows task fail", func(t *testing.T) { - ib.tasks[buildID+9] = indexTaskInit - ib.process(buildID + 9) - - state, ok := ib.tasks[buildID+9] - assert.True(t, ok) - assert.Equal(t, indexTaskInit, state) - }) - - t.Run("peek client fail", 
func(t *testing.T) { - ib.tasks[buildID] = indexTaskInit - ib.nodeManager = &IndexNodeManager{nodeClients: map[UniqueID]types.IndexNodeClient{}} - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskInit, state) - }) - - t.Run("update version fail", func(t *testing.T) { - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{1: &mclient.GrpcIndexNodeClient{Err: nil}}, - } - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskInit, state) - }) - - t.Run("no need to build index but update catalog failed", func(t *testing.T) { - ib.meta.catalog = ec - ib.meta.indexMeta.indexes[collID][indexID].IsDeleted = true - ib.tasks[buildID] = indexTaskInit - ok := ib.process(buildID) - assert.False(t, ok) - - _, ok = ib.tasks[buildID] - assert.True(t, ok) - }) - - t.Run("init no need to build index", func(t *testing.T) { - ib.meta.indexMeta.catalog = sc - ib.meta.catalog = sc - - ib.meta.indexMeta.indexes[collID][indexID].IsDeleted = true - ib.tasks[buildID] = indexTaskInit - ib.process(buildID) - - _, ok := ib.tasks[buildID] - assert.False(t, ok) - ib.meta.indexMeta.indexes[collID][indexID].IsDeleted = false - }) - - t.Run("assign task error", func(t *testing.T) { - paramtable.Get().Save(Params.CommonCfg.StorageType.Key, "local") - ib.tasks[buildID] = indexTaskInit - ib.meta.indexMeta.catalog = sc - ib.meta.catalog = sc - - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error")) - ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.GetJobStatsResponse{ - Status: merr.Success(), - TaskSlots: 1, - }, nil) - - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{ - 1: ic, - }, - } - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskRetry, state) - }) - t.Run("assign task fail", func(t *testing.T) { - paramtable.Get().Save(Params.CommonCfg.StorageType.Key, "local") - ib.meta.indexMeta.catalog = sc - ib.meta.catalog = sc - - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything).Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mock fail", - }, nil) - ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.GetJobStatsResponse{ - Status: merr.Success(), - TaskSlots: 1, - }, nil) - - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{ - 1: ic, - }, - } - ib.tasks[buildID] = indexTaskInit - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskRetry, state) - }) - - t.Run("drop job error", func(t *testing.T) { - ib.meta.indexMeta.buildID2SegmentIndex[buildID].NodeID = nodeID - ib.meta.catalog = sc - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().DropJobs(mock.Anything, mock.Anything, mock.Anything).Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, errors.New("error")) - - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{ - nodeID: ic, - }, - } - ib.tasks[buildID] = indexTaskDone - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskDone, state) - - 
ib.tasks[buildID] = indexTaskRetry - ib.process(buildID) - - state, ok = ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskRetry, state) - }) - - t.Run("drop job fail", func(t *testing.T) { - ib.meta.indexMeta.buildID2SegmentIndex[buildID].NodeID = nodeID - ib.meta.catalog = sc - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().DropJobs(mock.Anything, mock.Anything, mock.Anything).Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mock fail", - }, nil) - - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{ - nodeID: ic, - }, - } - ib.tasks[buildID] = indexTaskDone - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskDone, state) - - ib.tasks[buildID] = indexTaskRetry - ib.process(buildID) - - state, ok = ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskRetry, state) - }) - - t.Run("get state error", func(t *testing.T) { - ib.meta.indexMeta.buildID2SegmentIndex[buildID].NodeID = nodeID - ib.meta.catalog = sc - ib.meta.indexMeta.catalog = sc - - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error")) - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{ - nodeID: ic, - }, - } - - ib.tasks[buildID] = indexTaskInProgress - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskRetry, state) - }) - - t.Run("get state fail", func(t *testing.T) { - ib.meta.indexMeta.buildID2SegmentIndex[buildID].NodeID = nodeID - ib.meta.catalog = sc - ib.meta.indexMeta.catalog = sc - - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_BuildIndexError, - Reason: "mock fail", - }, - }, nil) - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{ - nodeID: ic, - }, - } - - ib.tasks[buildID] = indexTaskInProgress - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskRetry, state) - }) - - t.Run("finish task fail", func(t *testing.T) { - ib.meta.indexMeta.buildID2SegmentIndex[buildID].NodeID = nodeID - ib.meta.catalog = ec - ib.meta.indexMeta.catalog = ec - - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{ - Status: merr.Success(), - IndexInfos: []*indexpb.IndexTaskInfo{ - { - BuildID: buildID, - State: commonpb.IndexState_Finished, - IndexFileKeys: []string{"file1", "file2"}, - SerializedSize: 1024, - FailReason: "", - }, - }, - }, nil) - - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{ - nodeID: ic, - }, - } - - ib.tasks[buildID] = indexTaskInProgress - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskInProgress, state) - }) - - t.Run("task still in progress", func(t *testing.T) { - ib.meta.indexMeta.buildID2SegmentIndex[buildID].NodeID = nodeID - ib.meta.catalog = ec - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{ - Status: merr.Success(), - IndexInfos: 
[]*indexpb.IndexTaskInfo{ - { - BuildID: buildID, - State: commonpb.IndexState_InProgress, - IndexFileKeys: nil, - SerializedSize: 0, - FailReason: "", - }, - }, - }, nil) - - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{ - nodeID: ic, - }, - } - - ib.tasks[buildID] = indexTaskInProgress - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskInProgress, state) - }) - - t.Run("indexNode has no task", func(t *testing.T) { - ib.meta.indexMeta.buildID2SegmentIndex[buildID].NodeID = nodeID - ib.meta.catalog = sc - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{ - Status: merr.Success(), - IndexInfos: nil, - }, nil) - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{ - nodeID: ic, - }, - } - - ib.tasks[buildID] = indexTaskInProgress - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskRetry, state) - }) - - t.Run("node not exist", func(t *testing.T) { - ib.meta.indexMeta.buildID2SegmentIndex[buildID].NodeID = nodeID - ib.meta.catalog = sc - ib.nodeManager = &IndexNodeManager{ - ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNodeClient{}, - } - - ib.tasks[buildID] = indexTaskInProgress - ib.process(buildID) - - state, ok := ib.tasks[buildID] - assert.True(t, ok) - assert.Equal(t, indexTaskRetry, state) - }) -} - -func TestIndexBuilderV2(t *testing.T) { - var ( - collID = UniqueID(100) - partID = UniqueID(200) - indexID = UniqueID(300) - segID = UniqueID(500) - buildID = UniqueID(600) - nodeID = UniqueID(700) - ) - - paramtable.Init() - paramtable.Get().CommonCfg.EnableStorageV2.SwapTempValue("true") - defer paramtable.Get().CommonCfg.EnableStorageV2.SwapTempValue("false") - ctx := context.Background() - catalog := catalogmocks.NewDataCoordCatalog(t) - catalog.On("CreateSegmentIndex", - mock.Anything, - mock.Anything, - ).Return(nil) - catalog.On("AlterSegmentIndexes", - mock.Anything, - mock.Anything, - ).Return(nil) - - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything). - Return(&indexpb.GetJobStatsResponse{ - Status: merr.Success(), - TotalJobNum: 1, - EnqueueJobNum: 0, - InProgressJobNum: 1, - TaskSlots: 1, - JobInfos: []*indexpb.JobInfo{ - { - NumRows: 1024, - Dim: 128, - StartTime: 1, - EndTime: 10, - PodID: 1, - }, - }, - }, nil) - ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( - func(ctx context.Context, in *indexpb.QueryJobsRequest, option ...grpc.CallOption) (*indexpb.QueryJobsResponse, error) { - indexInfos := make([]*indexpb.IndexTaskInfo, 0) - for _, buildID := range in.BuildIDs { - indexInfos = append(indexInfos, &indexpb.IndexTaskInfo{ - BuildID: buildID, - State: commonpb.IndexState_Finished, - IndexFileKeys: []string{"file1", "file2"}, - }) - } - return &indexpb.QueryJobsResponse{ - Status: merr.Success(), - ClusterID: in.ClusterID, - IndexInfos: indexInfos, - }, nil - }) - - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(merr.Success(), nil) - - ic.EXPECT().DropJobs(mock.Anything, mock.Anything, mock.Anything). 
- Return(merr.Success(), nil) - mt := createMetaTable(catalog) - nodeManager := &IndexNodeManager{ - ctx: ctx, - nodeClients: map[UniqueID]types.IndexNodeClient{ - 4: ic, - }, - } - chunkManager := &mocks.ChunkManager{} - chunkManager.EXPECT().RootPath().Return("root") - - handler := NewNMockHandler(t) - handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ - ID: collID, - Schema: &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - {FieldID: fieldID, Name: "vec", TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "10"}}}, - }, - }, - }, nil) - - ib := newIndexBuilder(ctx, mt, nodeManager, chunkManager, newIndexEngineVersionManager(), handler) - - assert.Equal(t, 6, len(ib.tasks)) - assert.Equal(t, indexTaskInit, ib.tasks[buildID]) - assert.Equal(t, indexTaskInProgress, ib.tasks[buildID+1]) - // buildID+2 will be filter by isDeleted - assert.Equal(t, indexTaskInit, ib.tasks[buildID+3]) - assert.Equal(t, indexTaskInProgress, ib.tasks[buildID+8]) - assert.Equal(t, indexTaskInit, ib.tasks[buildID+9]) - assert.Equal(t, indexTaskInit, ib.tasks[buildID+10]) - - ib.scheduleDuration = time.Millisecond * 500 - ib.Start() - - t.Run("enqueue", func(t *testing.T) { - segIdx := &model.SegmentIndex{ - SegmentID: segID + 10, - CollectionID: collID, - PartitionID: partID, - NumRows: 1026, - IndexID: indexID, - BuildID: buildID + 10, - NodeID: 0, - IndexVersion: 0, - IndexState: 0, - FailReason: "", - IsDeleted: false, - CreateTime: 0, - IndexFileKeys: nil, - IndexSize: 0, - } - err := ib.meta.indexMeta.AddSegmentIndex(segIdx) - assert.NoError(t, err) - ib.enqueue(buildID + 10) - }) - - t.Run("node down", func(t *testing.T) { - ib.nodeDown(nodeID) - }) - - for { - ib.taskMutex.RLock() - if len(ib.tasks) == 0 { - break - } - ib.taskMutex.RUnlock() - } - ib.Stop() -} - -func TestVecIndexWithOptionalScalarField(t *testing.T) { - var ( - collID = UniqueID(100) - partID = UniqueID(200) - indexID = UniqueID(300) - segID = UniqueID(500) - buildID = UniqueID(600) - nodeID = UniqueID(700) - partitionKeyID = UniqueID(800) - ) - - paramtable.Init() - ctx := context.Background() - minNumberOfRowsToBuild := paramtable.Get().DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64() + 1 - - catalog := catalogmocks.NewDataCoordCatalog(t) - catalog.On("CreateSegmentIndex", - mock.Anything, - mock.Anything, - ).Return(nil) - catalog.On("AlterSegmentIndexes", - mock.Anything, - mock.Anything, - ).Return(nil) - - ic := mocks.NewMockIndexNodeClient(t) - ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything). - Return(&indexpb.GetJobStatsResponse{ - Status: merr.Success(), - TotalJobNum: 0, - EnqueueJobNum: 0, - InProgressJobNum: 0, - TaskSlots: 1, - JobInfos: []*indexpb.JobInfo{}, - }, nil) - ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( - func(ctx context.Context, in *indexpb.QueryJobsRequest, option ...grpc.CallOption) (*indexpb.QueryJobsResponse, error) { - indexInfos := make([]*indexpb.IndexTaskInfo, 0) - for _, buildID := range in.BuildIDs { - indexInfos = append(indexInfos, &indexpb.IndexTaskInfo{ - BuildID: buildID, - State: commonpb.IndexState_Finished, - IndexFileKeys: []string{"file1", "file2"}, - }) - } - return &indexpb.QueryJobsResponse{ - Status: merr.Success(), - ClusterID: in.ClusterID, - IndexInfos: indexInfos, - }, nil - }) - - ic.EXPECT().DropJobs(mock.Anything, mock.Anything, mock.Anything). 
- Return(merr.Success(), nil) - - mt := meta{ - catalog: catalog, - collections: map[int64]*collectionInfo{ - collID: { - ID: collID, - Schema: &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - { - FieldID: fieldID, - Name: "vec", - DataType: schemapb.DataType_FloatVector, - }, - { - FieldID: partitionKeyID, - Name: "scalar", - DataType: schemapb.DataType_VarChar, - IsPartitionKey: true, - }, - }, - }, - CreatedAt: 0, - }, - }, - - indexMeta: &indexMeta{ - catalog: catalog, - indexes: map[UniqueID]map[UniqueID]*model.Index{ - collID: { - indexID: { - TenantID: "", - CollectionID: collID, - FieldID: fieldID, - IndexID: indexID, - IndexName: indexName, - IsDeleted: false, - CreateTime: 1, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: "128", - }, - }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: common.MetricTypeKey, - Value: "L2", - }, - { - Key: common.IndexTypeKey, - Value: indexparamcheck.IndexHNSW, - }, - }, - }, - }, - }, - segmentIndexes: map[UniqueID]map[UniqueID]*model.SegmentIndex{ - segID: { - indexID: { - SegmentID: segID, - CollectionID: collID, - PartitionID: partID, - NumRows: minNumberOfRowsToBuild, - IndexID: indexID, - BuildID: buildID, - NodeID: 0, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 0, - IndexFileKeys: nil, - IndexSize: 0, - }, - }, - }, - buildID2SegmentIndex: map[UniqueID]*model.SegmentIndex{ - buildID: { - SegmentID: segID, - CollectionID: collID, - PartitionID: partID, - NumRows: minNumberOfRowsToBuild, - IndexID: indexID, - BuildID: buildID, - NodeID: 0, - IndexVersion: 0, - IndexState: commonpb.IndexState_Unissued, - FailReason: "", - IsDeleted: false, - CreateTime: 0, - IndexFileKeys: nil, - IndexSize: 0, - }, - }, - }, - segments: &SegmentsInfo{ - segments: map[UniqueID]*SegmentInfo{ - segID: { - SegmentInfo: &datapb.SegmentInfo{ - ID: segID, - CollectionID: collID, - PartitionID: partID, - InsertChannel: "", - NumOfRows: minNumberOfRowsToBuild, - State: commonpb.SegmentState_Flushed, - MaxRowNum: 65536, - LastExpireTime: 10, - }, - }, - }, - }, - } - - nodeManager := &IndexNodeManager{ - ctx: ctx, - nodeClients: map[UniqueID]types.IndexNodeClient{ - 1: ic, - }, - } - - cm := &mocks.ChunkManager{} - cm.EXPECT().RootPath().Return("root") - - waitTaskDoneFunc := func(builder *indexBuilder) { - for { - builder.taskMutex.RLock() - if len(builder.tasks) == 0 { - builder.taskMutex.RUnlock() - break - } - builder.taskMutex.RUnlock() - } - - assert.Zero(t, len(builder.tasks)) - } - - resetMetaFunc := func() { - mt.indexMeta.buildID2SegmentIndex[buildID].IndexState = commonpb.IndexState_Unissued - mt.indexMeta.segmentIndexes[segID][indexID].IndexState = commonpb.IndexState_Unissued - mt.indexMeta.indexes[collID][indexID].IndexParams[1].Value = indexparamcheck.IndexHNSW - mt.collections[collID].Schema.Fields[1].IsPartitionKey = true - mt.collections[collID].Schema.Fields[1].DataType = schemapb.DataType_VarChar - } - - handler := NewNMockHandler(t) - handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ - ID: collID, - Schema: &schemapb.CollectionSchema{ - Name: "coll", - Fields: []*schemapb.FieldSchema{ - { - FieldID: fieldID, - Name: "vec", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "128", - }, - }, - }, - }, - EnableDynamicField: false, - Properties: nil, - }, - }, nil) - - paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("true") - 
defer paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("false") - ib := newIndexBuilder(ctx, &mt, nodeManager, cm, newIndexEngineVersionManager(), handler) - - t.Run("success to get opt field on startup", func(t *testing.T) { - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( - func(ctx context.Context, in *indexpb.CreateJobRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - assert.NotZero(t, len(in.OptionalScalarFields), "optional scalar field should be set") - return merr.Success(), nil - }).Once() - assert.Equal(t, 1, len(ib.tasks)) - assert.Equal(t, indexTaskInit, ib.tasks[buildID]) - - ib.scheduleDuration = time.Millisecond * 500 - ib.Start() - waitTaskDoneFunc(ib) - resetMetaFunc() - }) - - segIdx := &model.SegmentIndex{ - SegmentID: segID, - CollectionID: collID, - PartitionID: partID, - NumRows: minNumberOfRowsToBuild, - IndexID: indexID, - BuildID: buildID, - NodeID: 0, - IndexVersion: 0, - IndexState: 0, - FailReason: "", - IsDeleted: false, - CreateTime: 0, - IndexFileKeys: nil, - IndexSize: 0, - } - - t.Run("enqueue varchar", func(t *testing.T) { - mt.collections[collID].Schema.Fields[1].DataType = schemapb.DataType_VarChar - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( - func(ctx context.Context, in *indexpb.CreateJobRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - assert.NotZero(t, len(in.OptionalScalarFields), "optional scalar field should be set") - return merr.Success(), nil - }).Once() - err := ib.meta.indexMeta.AddSegmentIndex(segIdx) - assert.NoError(t, err) - ib.enqueue(buildID) - waitTaskDoneFunc(ib) - resetMetaFunc() - }) - - t.Run("enqueue string", func(t *testing.T) { - mt.collections[collID].Schema.Fields[1].DataType = schemapb.DataType_String - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( - func(ctx context.Context, in *indexpb.CreateJobRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - assert.NotZero(t, len(in.OptionalScalarFields), "optional scalar field should be set") - return merr.Success(), nil - }).Once() - err := ib.meta.indexMeta.AddSegmentIndex(segIdx) - assert.NoError(t, err) - ib.enqueue(buildID) - waitTaskDoneFunc(ib) - resetMetaFunc() - }) - - // should still be able to build vec index when opt field is not set - t.Run("enqueue returns empty optional field when cfg disable", func(t *testing.T) { - paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("false") - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( - func(ctx context.Context, in *indexpb.CreateJobRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - assert.Zero(t, len(in.OptionalScalarFields), "optional scalar field should be set") - return merr.Success(), nil - }).Once() - err := ib.meta.indexMeta.AddSegmentIndex(segIdx) - assert.NoError(t, err) - ib.enqueue(buildID) - waitTaskDoneFunc(ib) - resetMetaFunc() - }) - - t.Run("enqueue returns empty optional field when the data type is not STRING or VARCHAR", func(t *testing.T) { - paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("true") - for _, dataType := range []schemapb.DataType{ - schemapb.DataType_Bool, - schemapb.DataType_Int8, - schemapb.DataType_Int16, - schemapb.DataType_Int32, - schemapb.DataType_Int64, - schemapb.DataType_Float, - schemapb.DataType_Double, - schemapb.DataType_Array, - schemapb.DataType_JSON, - } { - 
mt.collections[collID].Schema.Fields[1].DataType = dataType - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( - func(ctx context.Context, in *indexpb.CreateJobRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - assert.Zero(t, len(in.OptionalScalarFields), "optional scalar field should be set") - return merr.Success(), nil - }).Once() - err := ib.meta.indexMeta.AddSegmentIndex(segIdx) - assert.NoError(t, err) - ib.enqueue(buildID) - waitTaskDoneFunc(ib) - resetMetaFunc() - } - }) - - t.Run("enqueue returns empty optional field when no partition key", func(t *testing.T) { - paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("true") - mt.collections[collID].Schema.Fields[1].IsPartitionKey = false - ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( - func(ctx context.Context, in *indexpb.CreateJobRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - assert.Zero(t, len(in.OptionalScalarFields), "optional scalar field should be set") - return merr.Success(), nil - }).Once() - err := ib.meta.indexMeta.AddSegmentIndex(segIdx) - assert.NoError(t, err) - ib.enqueue(buildID) - waitTaskDoneFunc(ib) - resetMetaFunc() - }) - - ib.nodeDown(nodeID) - ib.Stop() -} diff --git a/internal/datacoord/index_meta.go b/internal/datacoord/index_meta.go index b3ad02de01..f257f472b6 100644 --- a/internal/datacoord/index_meta.go +++ b/internal/datacoord/index_meta.go @@ -653,18 +653,17 @@ func (m *indexMeta) IsIndexExist(collID, indexID UniqueID) bool { } // UpdateVersion updates the version and nodeID of the index meta, whenever the task is built once, the version will be updated once. -func (m *indexMeta) UpdateVersion(buildID UniqueID, nodeID UniqueID) error { +func (m *indexMeta) UpdateVersion(buildID UniqueID) error { m.Lock() defer m.Unlock() - log.Debug("IndexCoord metaTable UpdateVersion receive", zap.Int64("buildID", buildID), zap.Int64("nodeID", nodeID)) + log.Debug("IndexCoord metaTable UpdateVersion receive", zap.Int64("buildID", buildID)) segIdx, ok := m.buildID2SegmentIndex[buildID] if !ok { return fmt.Errorf("there is no index with buildID: %d", buildID) } updateFunc := func(segIdx *model.SegmentIndex) error { - segIdx.NodeID = nodeID segIdx.IndexVersion++ return m.alterSegmentIndexes([]*model.SegmentIndex{segIdx}) } @@ -728,7 +727,7 @@ func (m *indexMeta) DeleteTask(buildID int64) error { } // BuildIndex set the index state to be InProgress. It means IndexNode is building the index. 
-func (m *indexMeta) BuildIndex(buildID UniqueID) error { +func (m *indexMeta) BuildIndex(buildID, nodeID UniqueID) error { m.Lock() defer m.Unlock() @@ -738,6 +737,7 @@ func (m *indexMeta) BuildIndex(buildID UniqueID) error { } updateFunc := func(segIdx *model.SegmentIndex) error { + segIdx.NodeID = nodeID segIdx.IndexState = commonpb.IndexState_InProgress err := m.alterSegmentIndexes([]*model.SegmentIndex{segIdx}) @@ -828,7 +828,7 @@ func (m *indexMeta) RemoveIndex(collID, indexID UniqueID) error { return nil } -func (m *indexMeta) CleanSegmentIndex(buildID UniqueID) (bool, *model.SegmentIndex) { +func (m *indexMeta) CheckCleanSegmentIndex(buildID UniqueID) (bool, *model.SegmentIndex) { m.RLock() defer m.RUnlock() diff --git a/internal/datacoord/index_meta_test.go b/internal/datacoord/index_meta_test.go index 806e841c94..196071abd9 100644 --- a/internal/datacoord/index_meta_test.go +++ b/internal/datacoord/index_meta_test.go @@ -693,7 +693,8 @@ func TestMeta_MarkIndexAsDeleted(t *testing.T) { } func TestMeta_GetSegmentIndexes(t *testing.T) { - m := createMetaTable(&datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}) + catalog := &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)} + m := createMeta(catalog, nil, createIndexMeta(catalog)) t.Run("success", func(t *testing.T) { segIndexes := m.indexMeta.getSegmentIndexes(segID) @@ -1075,18 +1076,18 @@ func TestMeta_UpdateVersion(t *testing.T) { ).Return(errors.New("fail")) t.Run("success", func(t *testing.T) { - err := m.UpdateVersion(buildID, nodeID) + err := m.UpdateVersion(buildID) assert.NoError(t, err) }) t.Run("fail", func(t *testing.T) { m.catalog = ec - err := m.UpdateVersion(buildID, nodeID) + err := m.UpdateVersion(buildID) assert.Error(t, err) }) t.Run("not exist", func(t *testing.T) { - err := m.UpdateVersion(buildID+1, nodeID) + err := m.UpdateVersion(buildID + 1) assert.Error(t, err) }) } @@ -1143,18 +1144,18 @@ func TestMeta_BuildIndex(t *testing.T) { ).Return(errors.New("fail")) t.Run("success", func(t *testing.T) { - err := m.BuildIndex(buildID) + err := m.BuildIndex(buildID, nodeID) assert.NoError(t, err) }) t.Run("fail", func(t *testing.T) { m.catalog = ec - err := m.BuildIndex(buildID) + err := m.BuildIndex(buildID, nodeID) assert.Error(t, err) }) t.Run("not exist", func(t *testing.T) { - err := m.BuildIndex(buildID + 1) + err := m.BuildIndex(buildID+1, nodeID) assert.Error(t, err) }) } @@ -1330,7 +1331,8 @@ func TestRemoveSegmentIndex(t *testing.T) { } func TestIndexMeta_GetUnindexedSegments(t *testing.T) { - m := createMetaTable(&datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}) + catalog := &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)} + m := createMeta(catalog, nil, createIndexMeta(catalog)) // normal case segmentIDs := make([]int64, 0, 11) diff --git a/internal/datacoord/index_service.go b/internal/datacoord/index_service.go index 1db44438af..f72a0c3846 100644 --- a/internal/datacoord/index_service.go +++ b/internal/datacoord/index_service.go @@ -48,8 +48,6 @@ func (s *Server) serverID() int64 { } func (s *Server) startIndexService(ctx context.Context) { - s.indexBuilder.Start() - s.serverLoopWg.Add(1) go s.createIndexForSegmentLoop(ctx) } @@ -73,7 +71,13 @@ func (s *Server) createIndexForSegment(segment *SegmentInfo, indexID UniqueID) e if err = s.meta.indexMeta.AddSegmentIndex(segIndex); err != nil { return err } - s.indexBuilder.enqueue(buildID) + s.taskScheduler.enqueue(&indexBuildTask{ + buildID: buildID, + taskInfo: &indexpb.IndexTaskInfo{ + BuildID: buildID, + State: commonpb.IndexState_Unissued, + }, + }) return nil 
} diff --git a/internal/datacoord/indexnode_manager.go b/internal/datacoord/indexnode_manager.go index 4cd8560fba..890a9ed0e1 100644 --- a/internal/datacoord/indexnode_manager.go +++ b/internal/datacoord/indexnode_manager.go @@ -24,7 +24,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" @@ -33,6 +32,16 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) +type WorkerManager interface { + AddNode(nodeID UniqueID, address string) error + RemoveNode(nodeID UniqueID) + StoppingNode(nodeID UniqueID) + PickClient() (UniqueID, types.IndexNodeClient) + ClientSupportDisk() bool + GetAllClients() map[UniqueID]types.IndexNodeClient + GetClientByID(nodeID UniqueID) (types.IndexNodeClient, bool) +} + // IndexNodeManager is used to manage the client of IndexNode. type IndexNodeManager struct { nodeClients map[UniqueID]types.IndexNodeClient @@ -98,59 +107,55 @@ func (nm *IndexNodeManager) AddNode(nodeID UniqueID, address string) error { return nil } -// PeekClient peeks the client with the least load. -func (nm *IndexNodeManager) PeekClient(meta *model.SegmentIndex) (UniqueID, types.IndexNodeClient) { - allClients := nm.GetAllClients() - if len(allClients) == 0 { - log.Error("there is no IndexNode online") - return -1, nil - } +func (nm *IndexNodeManager) PickClient() (UniqueID, types.IndexNodeClient) { + nm.lock.Lock() + defer nm.lock.Unlock() // Note: In order to quickly end other goroutines, an error is returned when the client is successfully selected ctx, cancel := context.WithCancel(nm.ctx) var ( - peekNodeID = UniqueID(0) - nodeMutex = lock.Mutex{} + pickNodeID = UniqueID(0) + nodeMutex = sync.Mutex{} wg = sync.WaitGroup{} ) - for nodeID, client := range allClients { - nodeID := nodeID - client := client - wg.Add(1) - go func() { - defer wg.Done() - resp, err := client.GetJobStats(ctx, &indexpb.GetJobStatsRequest{}) - if err != nil { - log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), zap.Error(err)) - return - } - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), - zap.String("reason", resp.GetStatus().GetReason())) - return - } - if resp.GetTaskSlots() > 0 { - nodeMutex.Lock() - defer nodeMutex.Unlock() - log.Info("peek client success", zap.Int64("nodeID", nodeID)) - if peekNodeID == 0 { - peekNodeID = nodeID + for nodeID, client := range nm.nodeClients { + if _, ok := nm.stoppingNodes[nodeID]; !ok { + nodeID := nodeID + client := client + wg.Add(1) + go func() { + defer wg.Done() + resp, err := client.GetJobStats(ctx, &indexpb.GetJobStatsRequest{}) + if err != nil { + log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), zap.Error(err)) + return } - cancel() - // Note: In order to quickly end other goroutines, an error is returned when the client is successfully selected - return - } - }() + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), + zap.String("reason", resp.GetStatus().GetReason())) + return + } + if resp.GetTaskSlots() > 0 { + nodeMutex.Lock() + defer nodeMutex.Unlock() + if pickNodeID == 0 { + pickNodeID = nodeID + } + cancel() + // Note: In order to quickly end other goroutines, an error is returned when the 
client is successfully selected + return + } + }() + } } wg.Wait() cancel() - if peekNodeID != 0 { - log.Info("peek client success", zap.Int64("nodeID", peekNodeID)) - return peekNodeID, allClients[peekNodeID] + if pickNodeID != 0 { + log.Info("pick indexNode success", zap.Int64("nodeID", pickNodeID)) + return pickNodeID, nm.nodeClients[pickNodeID] } - log.RatedDebug(5, "peek client fail") return 0, nil } diff --git a/internal/datacoord/indexnode_manager_test.go b/internal/datacoord/indexnode_manager_test.go index c3948a040f..360953fea2 100644 --- a/internal/datacoord/indexnode_manager_test.go +++ b/internal/datacoord/indexnode_manager_test.go @@ -24,7 +24,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/types" @@ -34,9 +33,6 @@ import ( func TestIndexNodeManager_AddNode(t *testing.T) { nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc) - nodeID, client := nm.PeekClient(&model.SegmentIndex{}) - assert.Equal(t, int64(-1), nodeID) - assert.Nil(t, client) t.Run("success", func(t *testing.T) { err := nm.AddNode(1, "indexnode-1") @@ -49,7 +45,7 @@ func TestIndexNodeManager_AddNode(t *testing.T) { }) } -func TestIndexNodeManager_PeekClient(t *testing.T) { +func TestIndexNodeManager_PickClient(t *testing.T) { getMockedGetJobStatsClient := func(resp *indexpb.GetJobStatsResponse, err error) types.IndexNodeClient { ic := mocks.NewMockIndexNodeClient(t) ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(resp, err) @@ -94,9 +90,9 @@ func TestIndexNodeManager_PeekClient(t *testing.T) { }, } - nodeID, client := nm.PeekClient(&model.SegmentIndex{}) + selectNodeID, client := nm.PickClient() assert.NotNil(t, client) - assert.Contains(t, []UniqueID{8, 9}, nodeID) + assert.Contains(t, []UniqueID{8, 9}, selectNodeID) }) } diff --git a/internal/datacoord/meta.go b/internal/datacoord/meta.go index f3bbe6b507..a600e82fa1 100644 --- a/internal/datacoord/meta.go +++ b/internal/datacoord/meta.go @@ -57,10 +57,15 @@ type CompactionMeta interface { SetSegmentCompacting(segmentID int64, compacting bool) CheckAndSetSegmentsCompacting(segmentIDs []int64) (bool, bool) CompleteCompactionMutation(plan *datapb.CompactionPlan, result *datapb.CompactionPlanResult) ([]*SegmentInfo, *segMetricMutation, error) + SaveCompactionTask(task *datapb.CompactionTask) error DropCompactionTask(task *datapb.CompactionTask) error GetCompactionTasks() map[int64][]*datapb.CompactionTask GetCompactionTasksByTriggerID(triggerID int64) []*datapb.CompactionTask + + GetIndexMeta() *indexMeta + GetAnalyzeMeta() *analyzeMeta + GetCompactionTaskMeta() *compactionTaskMeta } var _ CompactionMeta = (*meta)(nil) @@ -75,9 +80,22 @@ type meta struct { chunkManager storage.ChunkManager indexMeta *indexMeta + analyzeMeta *analyzeMeta compactionTaskMeta *compactionTaskMeta } +func (m *meta) GetIndexMeta() *indexMeta { + return m.indexMeta +} + +func (m *meta) GetAnalyzeMeta() *analyzeMeta { + return m.analyzeMeta +} + +func (m *meta) GetCompactionTaskMeta() *compactionTaskMeta { + return m.compactionTaskMeta +} + type channelCPs struct { lock.RWMutex checkpoints map[string]*msgpb.MsgPosition @@ -110,7 +128,12 @@ type collectionInfo struct { // NewMeta creates meta from provided `kv.TxnKV` func newMeta(ctx context.Context, catalog metastore.DataCoordCatalog, chunkManager 
storage.ChunkManager) (*meta, error) { - indexMeta, err := newIndexMeta(ctx, catalog) + im, err := newIndexMeta(ctx, catalog) + if err != nil { + return nil, err + } + + am, err := newAnalyzeMeta(ctx, catalog) if err != nil { return nil, err } @@ -119,14 +142,14 @@ func newMeta(ctx context.Context, catalog metastore.DataCoordCatalog, chunkManag if err != nil { return nil, err } - mt := &meta{ ctx: ctx, catalog: catalog, collections: make(map[UniqueID]*collectionInfo), segments: NewSegmentsInfo(), channelCPs: newChannelCps(), - indexMeta: indexMeta, + indexMeta: im, + analyzeMeta: am, chunkManager: chunkManager, compactionTaskMeta: ctm, } diff --git a/internal/datacoord/meta_test.go b/internal/datacoord/meta_test.go index 05a26b3d54..4924a4c8e7 100644 --- a/internal/datacoord/meta_test.go +++ b/internal/datacoord/meta_test.go @@ -71,6 +71,7 @@ func (suite *MetaReloadSuite) TestReloadFromKV() { suite.catalog.EXPECT().ListSegments(mock.Anything).Return(nil, errors.New("mock")) suite.catalog.EXPECT().ListIndexes(mock.Anything).Return([]*model.Index{}, nil) suite.catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return([]*model.SegmentIndex{}, nil) + suite.catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil) suite.catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil) _, err := newMeta(ctx, suite.catalog, nil) @@ -84,6 +85,7 @@ func (suite *MetaReloadSuite) TestReloadFromKV() { suite.catalog.EXPECT().ListChannelCheckpoint(mock.Anything).Return(nil, errors.New("mock")) suite.catalog.EXPECT().ListIndexes(mock.Anything).Return([]*model.Index{}, nil) suite.catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return([]*model.SegmentIndex{}, nil) + suite.catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil) suite.catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil) _, err := newMeta(ctx, suite.catalog, nil) @@ -94,6 +96,7 @@ func (suite *MetaReloadSuite) TestReloadFromKV() { defer suite.resetMock() suite.catalog.EXPECT().ListIndexes(mock.Anything).Return([]*model.Index{}, nil) suite.catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return([]*model.SegmentIndex{}, nil) + suite.catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil) suite.catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil) suite.catalog.EXPECT().ListSegments(mock.Anything).Return([]*datapb.SegmentInfo{ { @@ -622,7 +625,7 @@ func TestMeta_Basic(t *testing.T) { }) t.Run("Test GetCollectionBinlogSize", func(t *testing.T) { - meta := createMetaTable(&datacoord.Catalog{}) + meta := createMeta(&datacoord.Catalog{}, nil, createIndexMeta(&datacoord.Catalog{})) ret := meta.GetCollectionIndexFilesSize() assert.Equal(t, uint64(0), ret) diff --git a/internal/datacoord/mock_compaction_meta.go b/internal/datacoord/mock_compaction_meta.go index 4f048ef1d3..a4e9173ca5 100644 --- a/internal/datacoord/mock_compaction_meta.go +++ b/internal/datacoord/mock_compaction_meta.go @@ -178,6 +178,92 @@ func (_c *MockCompactionMeta_DropCompactionTask_Call) RunAndReturn(run func(*dat return _c } +// GetAnalyzeMeta provides a mock function with given fields: +func (_m *MockCompactionMeta) GetAnalyzeMeta() *analyzeMeta { + ret := _m.Called() + + var r0 *analyzeMeta + if rf, ok := ret.Get(0).(func() *analyzeMeta); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*analyzeMeta) + } + } + + return r0 +} + +// MockCompactionMeta_GetAnalyzeMeta_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAnalyzeMeta' +type 
MockCompactionMeta_GetAnalyzeMeta_Call struct { + *mock.Call +} + +// GetAnalyzeMeta is a helper method to define mock.On call +func (_e *MockCompactionMeta_Expecter) GetAnalyzeMeta() *MockCompactionMeta_GetAnalyzeMeta_Call { + return &MockCompactionMeta_GetAnalyzeMeta_Call{Call: _e.mock.On("GetAnalyzeMeta")} +} + +func (_c *MockCompactionMeta_GetAnalyzeMeta_Call) Run(run func()) *MockCompactionMeta_GetAnalyzeMeta_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCompactionMeta_GetAnalyzeMeta_Call) Return(_a0 *analyzeMeta) *MockCompactionMeta_GetAnalyzeMeta_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCompactionMeta_GetAnalyzeMeta_Call) RunAndReturn(run func() *analyzeMeta) *MockCompactionMeta_GetAnalyzeMeta_Call { + _c.Call.Return(run) + return _c +} + +// GetCompactionTaskMeta provides a mock function with given fields: +func (_m *MockCompactionMeta) GetCompactionTaskMeta() *compactionTaskMeta { + ret := _m.Called() + + var r0 *compactionTaskMeta + if rf, ok := ret.Get(0).(func() *compactionTaskMeta); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*compactionTaskMeta) + } + } + + return r0 +} + +// MockCompactionMeta_GetCompactionTaskMeta_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompactionTaskMeta' +type MockCompactionMeta_GetCompactionTaskMeta_Call struct { + *mock.Call +} + +// GetCompactionTaskMeta is a helper method to define mock.On call +func (_e *MockCompactionMeta_Expecter) GetCompactionTaskMeta() *MockCompactionMeta_GetCompactionTaskMeta_Call { + return &MockCompactionMeta_GetCompactionTaskMeta_Call{Call: _e.mock.On("GetCompactionTaskMeta")} +} + +func (_c *MockCompactionMeta_GetCompactionTaskMeta_Call) Run(run func()) *MockCompactionMeta_GetCompactionTaskMeta_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCompactionMeta_GetCompactionTaskMeta_Call) Return(_a0 *compactionTaskMeta) *MockCompactionMeta_GetCompactionTaskMeta_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCompactionMeta_GetCompactionTaskMeta_Call) RunAndReturn(run func() *compactionTaskMeta) *MockCompactionMeta_GetCompactionTaskMeta_Call { + _c.Call.Return(run) + return _c +} + // GetCompactionTasks provides a mock function with given fields: func (_m *MockCompactionMeta) GetCompactionTasks() map[int64][]*datapb.CompactionTask { ret := _m.Called() @@ -309,6 +395,49 @@ func (_c *MockCompactionMeta_GetHealthySegment_Call) RunAndReturn(run func(int64 return _c } +// GetIndexMeta provides a mock function with given fields: +func (_m *MockCompactionMeta) GetIndexMeta() *indexMeta { + ret := _m.Called() + + var r0 *indexMeta + if rf, ok := ret.Get(0).(func() *indexMeta); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexMeta) + } + } + + return r0 +} + +// MockCompactionMeta_GetIndexMeta_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexMeta' +type MockCompactionMeta_GetIndexMeta_Call struct { + *mock.Call +} + +// GetIndexMeta is a helper method to define mock.On call +func (_e *MockCompactionMeta_Expecter) GetIndexMeta() *MockCompactionMeta_GetIndexMeta_Call { + return &MockCompactionMeta_GetIndexMeta_Call{Call: _e.mock.On("GetIndexMeta")} +} + +func (_c *MockCompactionMeta_GetIndexMeta_Call) Run(run func()) *MockCompactionMeta_GetIndexMeta_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c 
*MockCompactionMeta_GetIndexMeta_Call) Return(_a0 *indexMeta) *MockCompactionMeta_GetIndexMeta_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCompactionMeta_GetIndexMeta_Call) RunAndReturn(run func() *indexMeta) *MockCompactionMeta_GetIndexMeta_Call { + _c.Call.Return(run) + return _c +} + // GetSegment provides a mock function with given fields: segID func (_m *MockCompactionMeta) GetSegment(segID int64) *SegmentInfo { ret := _m.Called(segID) diff --git a/internal/datacoord/mock_worker_manager.go b/internal/datacoord/mock_worker_manager.go new file mode 100644 index 0000000000..6d2bc1ea79 --- /dev/null +++ b/internal/datacoord/mock_worker_manager.go @@ -0,0 +1,335 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package datacoord + +import ( + types "github.com/milvus-io/milvus/internal/types" + mock "github.com/stretchr/testify/mock" +) + +// MockWorkerManager is an autogenerated mock type for the WorkerManager type +type MockWorkerManager struct { + mock.Mock +} + +type MockWorkerManager_Expecter struct { + mock *mock.Mock +} + +func (_m *MockWorkerManager) EXPECT() *MockWorkerManager_Expecter { + return &MockWorkerManager_Expecter{mock: &_m.Mock} +} + +// AddNode provides a mock function with given fields: nodeID, address +func (_m *MockWorkerManager) AddNode(nodeID int64, address string) error { + ret := _m.Called(nodeID, address) + + var r0 error + if rf, ok := ret.Get(0).(func(int64, string) error); ok { + r0 = rf(nodeID, address) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockWorkerManager_AddNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddNode' +type MockWorkerManager_AddNode_Call struct { + *mock.Call +} + +// AddNode is a helper method to define mock.On call +// - nodeID int64 +// - address string +func (_e *MockWorkerManager_Expecter) AddNode(nodeID interface{}, address interface{}) *MockWorkerManager_AddNode_Call { + return &MockWorkerManager_AddNode_Call{Call: _e.mock.On("AddNode", nodeID, address)} +} + +func (_c *MockWorkerManager_AddNode_Call) Run(run func(nodeID int64, address string)) *MockWorkerManager_AddNode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(string)) + }) + return _c +} + +func (_c *MockWorkerManager_AddNode_Call) Return(_a0 error) *MockWorkerManager_AddNode_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWorkerManager_AddNode_Call) RunAndReturn(run func(int64, string) error) *MockWorkerManager_AddNode_Call { + _c.Call.Return(run) + return _c +} + +// ClientSupportDisk provides a mock function with given fields: +func (_m *MockWorkerManager) ClientSupportDisk() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockWorkerManager_ClientSupportDisk_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ClientSupportDisk' +type MockWorkerManager_ClientSupportDisk_Call struct { + *mock.Call +} + +// ClientSupportDisk is a helper method to define mock.On call +func (_e *MockWorkerManager_Expecter) ClientSupportDisk() *MockWorkerManager_ClientSupportDisk_Call { + return &MockWorkerManager_ClientSupportDisk_Call{Call: _e.mock.On("ClientSupportDisk")} +} + +func (_c *MockWorkerManager_ClientSupportDisk_Call) Run(run func()) *MockWorkerManager_ClientSupportDisk_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c 
*MockWorkerManager_ClientSupportDisk_Call) Return(_a0 bool) *MockWorkerManager_ClientSupportDisk_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWorkerManager_ClientSupportDisk_Call) RunAndReturn(run func() bool) *MockWorkerManager_ClientSupportDisk_Call { + _c.Call.Return(run) + return _c +} + +// GetAllClients provides a mock function with given fields: +func (_m *MockWorkerManager) GetAllClients() map[int64]types.IndexNodeClient { + ret := _m.Called() + + var r0 map[int64]types.IndexNodeClient + if rf, ok := ret.Get(0).(func() map[int64]types.IndexNodeClient); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]types.IndexNodeClient) + } + } + + return r0 +} + +// MockWorkerManager_GetAllClients_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAllClients' +type MockWorkerManager_GetAllClients_Call struct { + *mock.Call +} + +// GetAllClients is a helper method to define mock.On call +func (_e *MockWorkerManager_Expecter) GetAllClients() *MockWorkerManager_GetAllClients_Call { + return &MockWorkerManager_GetAllClients_Call{Call: _e.mock.On("GetAllClients")} +} + +func (_c *MockWorkerManager_GetAllClients_Call) Run(run func()) *MockWorkerManager_GetAllClients_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockWorkerManager_GetAllClients_Call) Return(_a0 map[int64]types.IndexNodeClient) *MockWorkerManager_GetAllClients_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWorkerManager_GetAllClients_Call) RunAndReturn(run func() map[int64]types.IndexNodeClient) *MockWorkerManager_GetAllClients_Call { + _c.Call.Return(run) + return _c +} + +// GetClientByID provides a mock function with given fields: nodeID +func (_m *MockWorkerManager) GetClientByID(nodeID int64) (types.IndexNodeClient, bool) { + ret := _m.Called(nodeID) + + var r0 types.IndexNodeClient + var r1 bool + if rf, ok := ret.Get(0).(func(int64) (types.IndexNodeClient, bool)); ok { + return rf(nodeID) + } + if rf, ok := ret.Get(0).(func(int64) types.IndexNodeClient); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.IndexNodeClient) + } + } + + if rf, ok := ret.Get(1).(func(int64) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// MockWorkerManager_GetClientByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetClientByID' +type MockWorkerManager_GetClientByID_Call struct { + *mock.Call +} + +// GetClientByID is a helper method to define mock.On call +// - nodeID int64 +func (_e *MockWorkerManager_Expecter) GetClientByID(nodeID interface{}) *MockWorkerManager_GetClientByID_Call { + return &MockWorkerManager_GetClientByID_Call{Call: _e.mock.On("GetClientByID", nodeID)} +} + +func (_c *MockWorkerManager_GetClientByID_Call) Run(run func(nodeID int64)) *MockWorkerManager_GetClientByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockWorkerManager_GetClientByID_Call) Return(_a0 types.IndexNodeClient, _a1 bool) *MockWorkerManager_GetClientByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockWorkerManager_GetClientByID_Call) RunAndReturn(run func(int64) (types.IndexNodeClient, bool)) *MockWorkerManager_GetClientByID_Call { + _c.Call.Return(run) + return _c +} + +// PickClient provides a mock function with given fields: +func (_m *MockWorkerManager) PickClient() (int64, types.IndexNodeClient) { + ret 
:= _m.Called() + + var r0 int64 + var r1 types.IndexNodeClient + if rf, ok := ret.Get(0).(func() (int64, types.IndexNodeClient)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func() types.IndexNodeClient); ok { + r1 = rf() + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(types.IndexNodeClient) + } + } + + return r0, r1 +} + +// MockWorkerManager_PickClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PickClient' +type MockWorkerManager_PickClient_Call struct { + *mock.Call +} + +// PickClient is a helper method to define mock.On call +func (_e *MockWorkerManager_Expecter) PickClient() *MockWorkerManager_PickClient_Call { + return &MockWorkerManager_PickClient_Call{Call: _e.mock.On("PickClient")} +} + +func (_c *MockWorkerManager_PickClient_Call) Run(run func()) *MockWorkerManager_PickClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockWorkerManager_PickClient_Call) Return(_a0 int64, _a1 types.IndexNodeClient) *MockWorkerManager_PickClient_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockWorkerManager_PickClient_Call) RunAndReturn(run func() (int64, types.IndexNodeClient)) *MockWorkerManager_PickClient_Call { + _c.Call.Return(run) + return _c +} + +// RemoveNode provides a mock function with given fields: nodeID +func (_m *MockWorkerManager) RemoveNode(nodeID int64) { + _m.Called(nodeID) +} + +// MockWorkerManager_RemoveNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveNode' +type MockWorkerManager_RemoveNode_Call struct { + *mock.Call +} + +// RemoveNode is a helper method to define mock.On call +// - nodeID int64 +func (_e *MockWorkerManager_Expecter) RemoveNode(nodeID interface{}) *MockWorkerManager_RemoveNode_Call { + return &MockWorkerManager_RemoveNode_Call{Call: _e.mock.On("RemoveNode", nodeID)} +} + +func (_c *MockWorkerManager_RemoveNode_Call) Run(run func(nodeID int64)) *MockWorkerManager_RemoveNode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockWorkerManager_RemoveNode_Call) Return() *MockWorkerManager_RemoveNode_Call { + _c.Call.Return() + return _c +} + +func (_c *MockWorkerManager_RemoveNode_Call) RunAndReturn(run func(int64)) *MockWorkerManager_RemoveNode_Call { + _c.Call.Return(run) + return _c +} + +// StoppingNode provides a mock function with given fields: nodeID +func (_m *MockWorkerManager) StoppingNode(nodeID int64) { + _m.Called(nodeID) +} + +// MockWorkerManager_StoppingNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StoppingNode' +type MockWorkerManager_StoppingNode_Call struct { + *mock.Call +} + +// StoppingNode is a helper method to define mock.On call +// - nodeID int64 +func (_e *MockWorkerManager_Expecter) StoppingNode(nodeID interface{}) *MockWorkerManager_StoppingNode_Call { + return &MockWorkerManager_StoppingNode_Call{Call: _e.mock.On("StoppingNode", nodeID)} +} + +func (_c *MockWorkerManager_StoppingNode_Call) Run(run func(nodeID int64)) *MockWorkerManager_StoppingNode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockWorkerManager_StoppingNode_Call) Return() *MockWorkerManager_StoppingNode_Call { + _c.Call.Return() + return _c +} + +func (_c *MockWorkerManager_StoppingNode_Call) RunAndReturn(run func(int64)) 
*MockWorkerManager_StoppingNode_Call { + _c.Call.Return(run) + return _c +} + +// NewMockWorkerManager creates a new instance of MockWorkerManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockWorkerManager(t interface { + mock.TestingT + Cleanup(func()) +}) *MockWorkerManager { + mock := &MockWorkerManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index bcfe185618..944723e056 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -152,10 +152,11 @@ type Server struct { // indexCoord types.IndexCoord // segReferManager *SegmentReferenceManager - indexBuilder *indexBuilder indexNodeManager *IndexNodeManager indexEngineVersionManager IndexEngineVersionManager + taskScheduler *taskScheduler + // manage ways that data coord access other coord broker broker.Broker } @@ -373,6 +374,7 @@ func (s *Server) initDataCoord() error { } log.Info("init service discovery done") + s.initTaskScheduler(storageCli) if Params.DataCoordCfg.EnableCompaction.GetAsBool() { s.createCompactionHandler() s.createCompactionTrigger() @@ -385,7 +387,6 @@ func (s *Server) initDataCoord() error { log.Info("init segment manager done") s.initGarbageCollection(storageCli) - s.initIndexBuilder(storageCli) s.importMeta, err = NewImportMeta(s.meta.catalog) if err != nil { @@ -419,6 +420,7 @@ func (s *Server) Start() error { } func (s *Server) startDataCoord() { + s.taskScheduler.Start() if Params.DataCoordCfg.EnableCompaction.GetAsBool() { s.compactionHandler.start() s.compactionTrigger.start() @@ -689,9 +691,9 @@ func (s *Server) initMeta(chunkManager storage.ChunkManager) error { return retry.Do(s.ctx, reloadEtcdFn, retry.Attempts(connMetaMaxRetryTime)) } -func (s *Server) initIndexBuilder(manager storage.ChunkManager) { - if s.indexBuilder == nil { - s.indexBuilder = newIndexBuilder(s.ctx, s.meta, s.indexNodeManager, manager, s.indexEngineVersionManager, s.handler) +func (s *Server) initTaskScheduler(manager storage.ChunkManager) { + if s.taskScheduler == nil { + s.taskScheduler = newTaskScheduler(s.ctx, s.meta, s.indexNodeManager, manager, s.indexEngineVersionManager, s.handler) } } @@ -1115,7 +1117,7 @@ func (s *Server) Stop() error { } logutil.Logger(s.ctx).Info("datacoord compaction stopped") - s.indexBuilder.Stop() + s.taskScheduler.Stop() logutil.Logger(s.ctx).Info("datacoord index builder stopped") s.cluster.Close() diff --git a/internal/datacoord/task_analyze.go b/internal/datacoord/task_analyze.go new file mode 100644 index 0000000000..a7a6d36f18 --- /dev/null +++ b/internal/datacoord/task_analyze.go @@ -0,0 +1,286 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package datacoord + +import ( + "context" + "fmt" + "math" + + "github.com/samber/lo" + "go.uber.org/zap" + "golang.org/x/exp/slices" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type analyzeTask struct { + taskID int64 + nodeID int64 + taskInfo *indexpb.AnalyzeResult +} + +func (at *analyzeTask) GetTaskID() int64 { + return at.taskID +} + +func (at *analyzeTask) GetNodeID() int64 { + return at.nodeID +} + +func (at *analyzeTask) ResetNodeID() { + at.nodeID = 0 +} + +func (at *analyzeTask) CheckTaskHealthy(mt *meta) bool { + t := mt.analyzeMeta.GetTask(at.GetTaskID()) + return t != nil +} + +func (at *analyzeTask) SetState(state indexpb.JobState, failReason string) { + at.taskInfo.State = state + at.taskInfo.FailReason = failReason +} + +func (at *analyzeTask) GetState() indexpb.JobState { + return at.taskInfo.GetState() +} + +func (at *analyzeTask) GetFailReason() string { + return at.taskInfo.GetFailReason() +} + +func (at *analyzeTask) UpdateVersion(ctx context.Context, meta *meta) error { + return meta.analyzeMeta.UpdateVersion(at.GetTaskID()) +} + +func (at *analyzeTask) UpdateMetaBuildingState(nodeID int64, meta *meta) error { + if err := meta.analyzeMeta.BuildingTask(at.GetTaskID(), nodeID); err != nil { + return err + } + at.nodeID = nodeID + return nil +} + +func (at *analyzeTask) AssignTask(ctx context.Context, client types.IndexNodeClient, dependency *taskScheduler) (bool, bool) { + t := dependency.meta.analyzeMeta.GetTask(at.GetTaskID()) + if t == nil { + log.Ctx(ctx).Info("task is nil, delete it", zap.Int64("taskID", at.GetTaskID())) + at.SetState(indexpb.JobState_JobStateNone, "analyze task is nil") + return false, false + } + + var storageConfig *indexpb.StorageConfig + if Params.CommonCfg.StorageType.GetValue() == "local" { + storageConfig = &indexpb.StorageConfig{ + RootPath: Params.LocalStorageCfg.Path.GetValue(), + StorageType: Params.CommonCfg.StorageType.GetValue(), + } + } else { + storageConfig = &indexpb.StorageConfig{ + Address: Params.MinioCfg.Address.GetValue(), + AccessKeyID: Params.MinioCfg.AccessKeyID.GetValue(), + SecretAccessKey: Params.MinioCfg.SecretAccessKey.GetValue(), + UseSSL: Params.MinioCfg.UseSSL.GetAsBool(), + BucketName: Params.MinioCfg.BucketName.GetValue(), + RootPath: Params.MinioCfg.RootPath.GetValue(), + UseIAM: Params.MinioCfg.UseIAM.GetAsBool(), + IAMEndpoint: Params.MinioCfg.IAMEndpoint.GetValue(), + StorageType: Params.CommonCfg.StorageType.GetValue(), + Region: Params.MinioCfg.Region.GetValue(), + UseVirtualHost: Params.MinioCfg.UseVirtualHost.GetAsBool(), + CloudProvider: Params.MinioCfg.CloudProvider.GetValue(), + RequestTimeoutMs: Params.MinioCfg.RequestTimeoutMs.GetAsInt64(), + } + } + req := &indexpb.AnalyzeRequest{ + ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), + TaskID: at.GetTaskID(), + CollectionID: t.CollectionID, + PartitionID: t.PartitionID, + FieldID: t.FieldID, + FieldName: t.FieldName, + FieldType: t.FieldType, + Dim: t.Dim, + SegmentStats: make(map[int64]*indexpb.SegmentStats), + Version: t.Version, + StorageConfig: storageConfig, + } + + // When data analyze occurs, segments must not be discarded. 
Such as compaction, GC, etc. + segments := dependency.meta.SelectSegments(SegmentFilterFunc(func(info *SegmentInfo) bool { + return isSegmentHealthy(info) && slices.Contains(t.SegmentIDs, info.ID) + })) + segmentsMap := lo.SliceToMap(segments, func(t *SegmentInfo) (int64, *SegmentInfo) { + return t.ID, t + }) + + totalSegmentsRows := int64(0) + for _, segID := range t.SegmentIDs { + info := segmentsMap[segID] + if info == nil { + log.Ctx(ctx).Warn("analyze stats task is processing, but segment is nil, delete the task", + zap.Int64("taskID", at.GetTaskID()), zap.Int64("segmentID", segID)) + at.SetState(indexpb.JobState_JobStateFailed, fmt.Sprintf("segmentInfo with ID: %d is nil", segID)) + return false, false + } + + totalSegmentsRows += info.GetNumOfRows() + // get binlogIDs + binlogIDs := getBinLogIDs(info, t.FieldID) + req.SegmentStats[segID] = &indexpb.SegmentStats{ + ID: segID, + NumRows: info.GetNumOfRows(), + LogIDs: binlogIDs, + } + } + + collInfo, err := dependency.handler.GetCollection(ctx, segments[0].GetCollectionID()) + if err != nil { + log.Ctx(ctx).Info("analyze task get collection info failed", zap.Int64("collectionID", + segments[0].GetCollectionID()), zap.Error(err)) + at.SetState(indexpb.JobState_JobStateInit, err.Error()) + return false, false + } + + schema := collInfo.Schema + var field *schemapb.FieldSchema + + for _, f := range schema.Fields { + if f.FieldID == t.FieldID { + field = f + break + } + } + dim, err := storage.GetDimFromParams(field.TypeParams) + if err != nil { + at.SetState(indexpb.JobState_JobStateInit, err.Error()) + return false, false + } + req.Dim = int64(dim) + + totalSegmentsRawDataSize := float64(totalSegmentsRows) * float64(dim) * typeutil.VectorTypeSize(t.FieldType) // Byte + numClusters := int64(math.Ceil(totalSegmentsRawDataSize / float64(Params.DataCoordCfg.ClusteringCompactionPreferSegmentSize.GetAsSize()))) + if numClusters < Params.DataCoordCfg.ClusteringCompactionMinCentroidsNum.GetAsInt64() { + log.Ctx(ctx).Info("data size is too small, skip analyze task", zap.Float64("raw data size", totalSegmentsRawDataSize), zap.Int64("num clusters", numClusters), zap.Int64("minimum num clusters required", Params.DataCoordCfg.ClusteringCompactionMinCentroidsNum.GetAsInt64())) + at.SetState(indexpb.JobState_JobStateFinished, "") + return true, true + } + if numClusters > Params.DataCoordCfg.ClusteringCompactionMaxCentroidsNum.GetAsInt64() { + numClusters = Params.DataCoordCfg.ClusteringCompactionMaxCentroidsNum.GetAsInt64() + } + req.NumClusters = numClusters + req.MaxTrainSizeRatio = Params.DataCoordCfg.ClusteringCompactionMaxTrainSizeRatio.GetAsFloat() // control clustering train data size + // config to detect data skewness + req.MinClusterSizeRatio = Params.DataCoordCfg.ClusteringCompactionMinClusterSizeRatio.GetAsFloat() + req.MaxClusterSizeRatio = Params.DataCoordCfg.ClusteringCompactionMaxClusterSizeRatio.GetAsFloat() + req.MaxClusterSize = Params.DataCoordCfg.ClusteringCompactionMaxClusterSize.GetAsSize() + + ctx, cancel := context.WithTimeout(context.Background(), reqTimeoutInterval) + defer cancel() + resp, err := client.CreateJobV2(ctx, &indexpb.CreateJobV2Request{ + ClusterID: req.GetClusterID(), + TaskID: req.GetTaskID(), + JobType: indexpb.JobType_JobTypeAnalyzeJob, + Request: &indexpb.CreateJobV2Request_AnalyzeRequest{ + AnalyzeRequest: req, + }, + }) + if err == nil { + err = merr.Error(resp) + } + if err != nil { + log.Ctx(ctx).Warn("assign analyze task to indexNode failed", zap.Int64("taskID", at.GetTaskID()), zap.Error(err)) + 
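+ // CreateJobV2 was rejected or the RPC failed; mark the task for retry so the
+ // scheduler can drop it on the worker and re-assign it from the init state.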
at.SetState(indexpb.JobState_JobStateRetry, err.Error()) + return false, true + } + + log.Ctx(ctx).Info("analyze task assigned successfully", zap.Int64("taskID", at.GetTaskID())) + at.SetState(indexpb.JobState_JobStateInProgress, "") + return true, false +} + +func (at *analyzeTask) setResult(result *indexpb.AnalyzeResult) { + at.taskInfo = result +} + +func (at *analyzeTask) QueryResult(ctx context.Context, client types.IndexNodeClient) { + resp, err := client.QueryJobsV2(ctx, &indexpb.QueryJobsV2Request{ + ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), + TaskIDs: []int64{at.GetTaskID()}, + JobType: indexpb.JobType_JobTypeAnalyzeJob, + }) + if err == nil { + err = merr.Error(resp.GetStatus()) + } + if err != nil { + log.Ctx(ctx).Warn("query analysis task result from IndexNode fail", zap.Int64("nodeID", at.GetNodeID()), + zap.Error(err)) + at.SetState(indexpb.JobState_JobStateRetry, err.Error()) + return + } + + // infos length is always one. + for _, result := range resp.GetAnalyzeJobResults().GetResults() { + if result.GetTaskID() == at.GetTaskID() { + log.Ctx(ctx).Info("query analysis task info successfully", + zap.Int64("taskID", at.GetTaskID()), zap.String("result state", result.GetState().String()), + zap.String("failReason", result.GetFailReason())) + if result.GetState() == indexpb.JobState_JobStateFinished || result.GetState() == indexpb.JobState_JobStateFailed || + result.GetState() == indexpb.JobState_JobStateRetry { + // state is retry or finished or failed + at.setResult(result) + } else if result.GetState() == indexpb.JobState_JobStateNone { + at.SetState(indexpb.JobState_JobStateRetry, "analyze task state is none in info response") + } + // inProgress or unissued/init, keep InProgress state + return + } + } + log.Ctx(ctx).Warn("query analyze task info failed, indexNode does not have task info", + zap.Int64("taskID", at.GetTaskID())) + at.SetState(indexpb.JobState_JobStateRetry, "analyze result is not in info response") +} + +func (at *analyzeTask) DropTaskOnWorker(ctx context.Context, client types.IndexNodeClient) bool { + resp, err := client.DropJobsV2(ctx, &indexpb.DropJobsV2Request{ + ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), + TaskIDs: []UniqueID{at.GetTaskID()}, + JobType: indexpb.JobType_JobTypeAnalyzeJob, + }) + if err == nil { + err = merr.Error(resp) + } + if err != nil { + log.Ctx(ctx).Warn("notify worker drop the analysis task fail", zap.Int64("taskID", at.GetTaskID()), + zap.Int64("nodeID", at.GetNodeID()), zap.Error(err)) + return false + } + log.Ctx(ctx).Info("drop analyze on worker success", + zap.Int64("taskID", at.GetTaskID()), zap.Int64("nodeID", at.GetNodeID())) + return true +} + +func (at *analyzeTask) SetJobInfo(meta *meta) error { + return meta.analyzeMeta.FinishTask(at.GetTaskID(), at.taskInfo) +} diff --git a/internal/datacoord/task_index.go b/internal/datacoord/task_index.go new file mode 100644 index 0000000000..a702934bc4 --- /dev/null +++ b/internal/datacoord/task_index.go @@ -0,0 +1,337 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datacoord + +import ( + "context" + "path" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/querycoordv2/params" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/types" + itypeutil "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/indexparams" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type indexBuildTask struct { + buildID int64 + nodeID int64 + taskInfo *indexpb.IndexTaskInfo +} + +func (it *indexBuildTask) GetTaskID() int64 { + return it.buildID +} + +func (it *indexBuildTask) GetNodeID() int64 { + return it.nodeID +} + +func (it *indexBuildTask) ResetNodeID() { + it.nodeID = 0 +} + +func (it *indexBuildTask) CheckTaskHealthy(mt *meta) bool { + _, exist := mt.indexMeta.GetIndexJob(it.GetTaskID()) + return exist +} + +func (it *indexBuildTask) SetState(state indexpb.JobState, failReason string) { + it.taskInfo.State = commonpb.IndexState(state) + it.taskInfo.FailReason = failReason +} + +func (it *indexBuildTask) GetState() indexpb.JobState { + return indexpb.JobState(it.taskInfo.GetState()) +} + +func (it *indexBuildTask) GetFailReason() string { + return it.taskInfo.FailReason +} + +func (it *indexBuildTask) UpdateVersion(ctx context.Context, meta *meta) error { + return meta.indexMeta.UpdateVersion(it.buildID) +} + +func (it *indexBuildTask) UpdateMetaBuildingState(nodeID int64, meta *meta) error { + it.nodeID = nodeID + return meta.indexMeta.BuildIndex(it.buildID, nodeID) +} + +func (it *indexBuildTask) AssignTask(ctx context.Context, client types.IndexNodeClient, dependency *taskScheduler) (bool, bool) { + segIndex, exist := dependency.meta.indexMeta.GetIndexJob(it.buildID) + if !exist || segIndex == nil { + log.Ctx(ctx).Info("index task has not exist in meta table, remove task", zap.Int64("buildID", it.buildID)) + it.SetState(indexpb.JobState_JobStateNone, "index task has not exist in meta table") + return false, false + } + + segment := dependency.meta.GetSegment(segIndex.SegmentID) + if !isSegmentHealthy(segment) || !dependency.meta.indexMeta.IsIndexExist(segIndex.CollectionID, segIndex.IndexID) { + log.Ctx(ctx).Info("task is no need to build index, remove it", zap.Int64("buildID", it.buildID)) + it.SetState(indexpb.JobState_JobStateNone, "task is no need to build index") + return false, false + } + indexParams := dependency.meta.indexMeta.GetIndexParams(segIndex.CollectionID, segIndex.IndexID) + indexType := GetIndexType(indexParams) + if isFlatIndex(indexType) || segIndex.NumRows < Params.DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64() { + log.Ctx(ctx).Info("segment does not need index really", zap.Int64("buildID", it.buildID), + zap.Int64("segmentID", segIndex.SegmentID), zap.Int64("num rows", segIndex.NumRows)) + 
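+ // FLAT indexes and segments below MinSegmentNumRowsToEnableIndex skip the actual
+ // build; the task is marked finished immediately.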
it.SetState(indexpb.JobState_JobStateFinished, "fake finished index success") + return true, true + } + // vector index build needs information of optional scalar fields data + optionalFields := make([]*indexpb.OptionalFieldInfo, 0) + if Params.CommonCfg.EnableMaterializedView.GetAsBool() && isOptionalScalarFieldSupported(indexType) { + collInfo, err := dependency.handler.GetCollection(ctx, segIndex.CollectionID) + if err != nil || collInfo == nil { + log.Ctx(ctx).Warn("get collection failed", zap.Int64("collID", segIndex.CollectionID), zap.Error(err)) + it.SetState(indexpb.JobState_JobStateInit, err.Error()) + return false, false + } + colSchema := collInfo.Schema + partitionKeyField, err := typeutil.GetPartitionKeyFieldSchema(colSchema) + if partitionKeyField == nil || err != nil { + log.Ctx(ctx).Warn("index builder get partition key field failed", zap.Int64("buildID", it.buildID), zap.Error(err)) + } else { + optionalFields = append(optionalFields, &indexpb.OptionalFieldInfo{ + FieldID: partitionKeyField.FieldID, + FieldName: partitionKeyField.Name, + FieldType: int32(partitionKeyField.DataType), + DataIds: getBinLogIDs(segment, partitionKeyField.FieldID), + }) + } + } + + typeParams := dependency.meta.indexMeta.GetTypeParams(segIndex.CollectionID, segIndex.IndexID) + + var storageConfig *indexpb.StorageConfig + if Params.CommonCfg.StorageType.GetValue() == "local" { + storageConfig = &indexpb.StorageConfig{ + RootPath: Params.LocalStorageCfg.Path.GetValue(), + StorageType: Params.CommonCfg.StorageType.GetValue(), + } + } else { + storageConfig = &indexpb.StorageConfig{ + Address: Params.MinioCfg.Address.GetValue(), + AccessKeyID: Params.MinioCfg.AccessKeyID.GetValue(), + SecretAccessKey: Params.MinioCfg.SecretAccessKey.GetValue(), + UseSSL: Params.MinioCfg.UseSSL.GetAsBool(), + SslCACert: Params.MinioCfg.SslCACert.GetValue(), + BucketName: Params.MinioCfg.BucketName.GetValue(), + RootPath: Params.MinioCfg.RootPath.GetValue(), + UseIAM: Params.MinioCfg.UseIAM.GetAsBool(), + IAMEndpoint: Params.MinioCfg.IAMEndpoint.GetValue(), + StorageType: Params.CommonCfg.StorageType.GetValue(), + Region: Params.MinioCfg.Region.GetValue(), + UseVirtualHost: Params.MinioCfg.UseVirtualHost.GetAsBool(), + CloudProvider: Params.MinioCfg.CloudProvider.GetValue(), + RequestTimeoutMs: Params.MinioCfg.RequestTimeoutMs.GetAsInt64(), + } + } + + fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID) + binlogIDs := getBinLogIDs(segment, fieldID) + if isDiskANNIndex(GetIndexType(indexParams)) { + var err error + indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams) + if err != nil { + log.Ctx(ctx).Warn("failed to append index build params", zap.Int64("buildID", it.buildID), zap.Error(err)) + it.SetState(indexpb.JobState_JobStateInit, err.Error()) + return false, false + } + } + var req *indexpb.CreateJobRequest + collectionInfo, err := dependency.handler.GetCollection(ctx, segment.GetCollectionID()) + if err != nil { + log.Ctx(ctx).Info("index builder get collection info failed", zap.Int64("collectionID", segment.GetCollectionID()), zap.Error(err)) + return false, false + } + + schema := collectionInfo.Schema + var field *schemapb.FieldSchema + + for _, f := range schema.Fields { + if f.FieldID == fieldID { + field = f + break + } + } + + dim, err := storage.GetDimFromParams(field.TypeParams) + if err != nil { + log.Ctx(ctx).Warn("failed to get dim from field type params", + zap.String("field type", field.GetDataType().String()), zap.Error(err)) 
+ // don't return, maybe field is scalar field or sparseFloatVector + } + + if Params.CommonCfg.EnableStorageV2.GetAsBool() { + storePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue(), segment.GetID()) + if err != nil { + log.Ctx(ctx).Warn("failed to get storage uri", zap.Error(err)) + it.SetState(indexpb.JobState_JobStateInit, err.Error()) + return false, false + } + indexStorePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue()+"/index", segment.GetID()) + if err != nil { + log.Ctx(ctx).Warn("failed to get storage uri", zap.Error(err)) + it.SetState(indexpb.JobState_JobStateInit, err.Error()) + return false, false + } + + req = &indexpb.CreateJobRequest{ + ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), + IndexFilePrefix: path.Join(dependency.chunkManager.RootPath(), common.SegmentIndexPath), + BuildID: it.buildID, + IndexVersion: segIndex.IndexVersion, + StorageConfig: storageConfig, + IndexParams: indexParams, + TypeParams: typeParams, + NumRows: segIndex.NumRows, + CollectionID: segment.GetCollectionID(), + PartitionID: segment.GetPartitionID(), + SegmentID: segment.GetID(), + FieldID: fieldID, + FieldName: field.Name, + FieldType: field.DataType, + StorePath: storePath, + StoreVersion: segment.GetStorageVersion(), + IndexStorePath: indexStorePath, + Dim: int64(dim), + CurrentIndexVersion: dependency.indexEngineVersionManager.GetCurrentIndexEngineVersion(), + DataIds: binlogIDs, + OptionalScalarFields: optionalFields, + Field: field, + } + } else { + req = &indexpb.CreateJobRequest{ + ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), + IndexFilePrefix: path.Join(dependency.chunkManager.RootPath(), common.SegmentIndexPath), + BuildID: it.buildID, + IndexVersion: segIndex.IndexVersion, + StorageConfig: storageConfig, + IndexParams: indexParams, + TypeParams: typeParams, + NumRows: segIndex.NumRows, + CurrentIndexVersion: dependency.indexEngineVersionManager.GetCurrentIndexEngineVersion(), + DataIds: binlogIDs, + CollectionID: segment.GetCollectionID(), + PartitionID: segment.GetPartitionID(), + SegmentID: segment.GetID(), + FieldID: fieldID, + OptionalScalarFields: optionalFields, + Dim: int64(dim), + Field: field, + } + } + + ctx, cancel := context.WithTimeout(context.Background(), reqTimeoutInterval) + defer cancel() + resp, err := client.CreateJobV2(ctx, &indexpb.CreateJobV2Request{ + ClusterID: req.GetClusterID(), + TaskID: req.GetBuildID(), + JobType: indexpb.JobType_JobTypeIndexJob, + Request: &indexpb.CreateJobV2Request_IndexRequest{ + IndexRequest: req, + }, + }) + if err == nil { + err = merr.Error(resp) + } + if err != nil { + log.Ctx(ctx).Warn("assign index task to indexNode failed", zap.Int64("buildID", it.buildID), zap.Error(err)) + it.SetState(indexpb.JobState_JobStateRetry, err.Error()) + return false, true + } + + log.Ctx(ctx).Info("index task assigned successfully", zap.Int64("buildID", it.buildID), + zap.Int64("segmentID", segIndex.SegmentID)) + it.SetState(indexpb.JobState_JobStateInProgress, "") + return true, false +} + +func (it *indexBuildTask) setResult(info *indexpb.IndexTaskInfo) { + it.taskInfo = info +} + +func (it *indexBuildTask) QueryResult(ctx context.Context, node types.IndexNodeClient) { + resp, err := node.QueryJobsV2(ctx, &indexpb.QueryJobsV2Request{ + ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), + TaskIDs: []UniqueID{it.GetTaskID()}, + JobType: 
indexpb.JobType_JobTypeIndexJob, + }) + if err == nil { + err = merr.Error(resp.GetStatus()) + } + if err != nil { + log.Ctx(ctx).Warn("get jobs info from IndexNode failed", zap.Int64("buildID", it.GetTaskID()), + zap.Int64("nodeID", it.GetNodeID()), zap.Error(err)) + it.SetState(indexpb.JobState_JobStateRetry, err.Error()) + return + } + + // indexInfos length is always one. + for _, info := range resp.GetIndexJobResults().GetResults() { + if info.GetBuildID() == it.GetTaskID() { + log.Ctx(ctx).Info("query task index info successfully", + zap.Int64("taskID", it.GetTaskID()), zap.String("result state", info.GetState().String()), + zap.String("failReason", info.GetFailReason())) + if info.GetState() == commonpb.IndexState_Finished || info.GetState() == commonpb.IndexState_Failed || + info.GetState() == commonpb.IndexState_Retry { + // state is retry or finished or failed + it.setResult(info) + } else if info.GetState() == commonpb.IndexState_IndexStateNone { + it.SetState(indexpb.JobState_JobStateRetry, "index state is none in info response") + } + // inProgress or unissued, keep InProgress state + return + } + } + it.SetState(indexpb.JobState_JobStateRetry, "index is not in info response") +} + +func (it *indexBuildTask) DropTaskOnWorker(ctx context.Context, client types.IndexNodeClient) bool { + resp, err := client.DropJobsV2(ctx, &indexpb.DropJobsV2Request{ + ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), + TaskIDs: []UniqueID{it.GetTaskID()}, + JobType: indexpb.JobType_JobTypeIndexJob, + }) + if err == nil { + err = merr.Error(resp) + } + if err != nil { + log.Ctx(ctx).Warn("notify worker drop the index task fail", zap.Int64("taskID", it.GetTaskID()), + zap.Int64("nodeID", it.GetNodeID()), zap.Error(err)) + return false + } + log.Ctx(ctx).Info("drop index task on worker success", zap.Int64("taskID", it.GetTaskID()), + zap.Int64("nodeID", it.GetNodeID())) + return true +} + +func (it *indexBuildTask) SetJobInfo(meta *meta) error { + return meta.indexMeta.FinishTask(it.taskInfo) +} diff --git a/internal/datacoord/task_scheduler.go b/internal/datacoord/task_scheduler.go new file mode 100644 index 0000000000..6b07689551 --- /dev/null +++ b/internal/datacoord/task_scheduler.go @@ -0,0 +1,296 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
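+// taskScheduler replaces the removed indexBuilder: it keeps analyze and index-build
+// tasks in memory, reloads unfinished ones from meta on startup, and drives each
+// task through init -> in-progress -> finished/failed, falling back to the retry
+// state when a worker call fails.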
+ +package datacoord + +import ( + "context" + "sync" + "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/log" +) + +const ( + reqTimeoutInterval = time.Second * 10 +) + +type taskScheduler struct { + sync.RWMutex + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + scheduleDuration time.Duration + + // TODO @xiaocai2333: use priority queue + tasks map[int64]Task + notifyChan chan struct{} + + meta *meta + + policy buildIndexPolicy + nodeManager WorkerManager + chunkManager storage.ChunkManager + indexEngineVersionManager IndexEngineVersionManager + handler Handler +} + +func newTaskScheduler( + ctx context.Context, + metaTable *meta, nodeManager WorkerManager, + chunkManager storage.ChunkManager, + indexEngineVersionManager IndexEngineVersionManager, + handler Handler, +) *taskScheduler { + ctx, cancel := context.WithCancel(ctx) + + ts := &taskScheduler{ + ctx: ctx, + cancel: cancel, + meta: metaTable, + tasks: make(map[int64]Task), + notifyChan: make(chan struct{}, 1), + scheduleDuration: Params.DataCoordCfg.IndexTaskSchedulerInterval.GetAsDuration(time.Millisecond), + policy: defaultBuildIndexPolicy, + nodeManager: nodeManager, + chunkManager: chunkManager, + handler: handler, + indexEngineVersionManager: indexEngineVersionManager, + } + ts.reloadFromKV() + return ts +} + +func (s *taskScheduler) Start() { + s.wg.Add(1) + go s.schedule() +} + +func (s *taskScheduler) Stop() { + s.cancel() + s.wg.Wait() +} + +func (s *taskScheduler) reloadFromKV() { + segments := s.meta.GetAllSegmentsUnsafe() + for _, segment := range segments { + for _, segIndex := range s.meta.indexMeta.getSegmentIndexes(segment.ID) { + if segIndex.IsDeleted { + continue + } + if segIndex.IndexState != commonpb.IndexState_Finished && segIndex.IndexState != commonpb.IndexState_Failed { + s.tasks[segIndex.BuildID] = &indexBuildTask{ + buildID: segIndex.BuildID, + nodeID: segIndex.NodeID, + taskInfo: &indexpb.IndexTaskInfo{ + BuildID: segIndex.BuildID, + State: segIndex.IndexState, + FailReason: segIndex.FailReason, + }, + } + } + } + } + + allAnalyzeTasks := s.meta.analyzeMeta.GetAllTasks() + for taskID, t := range allAnalyzeTasks { + if t.State != indexpb.JobState_JobStateFinished && t.State != indexpb.JobState_JobStateFailed { + s.tasks[taskID] = &analyzeTask{ + taskID: taskID, + nodeID: t.NodeID, + taskInfo: &indexpb.AnalyzeResult{ + TaskID: taskID, + State: t.State, + FailReason: t.FailReason, + }, + } + } + } +} + +// notify is an unblocked notify function +func (s *taskScheduler) notify() { + select { + case s.notifyChan <- struct{}{}: + default: + } +} + +func (s *taskScheduler) enqueue(task Task) { + defer s.notify() + + s.Lock() + defer s.Unlock() + taskID := task.GetTaskID() + if _, ok := s.tasks[taskID]; !ok { + s.tasks[taskID] = task + } + log.Info("taskScheduler enqueue task", zap.Int64("taskID", taskID)) +} + +func (s *taskScheduler) schedule() { + // receive notifyChan + // time ticker + log.Ctx(s.ctx).Info("task scheduler loop start") + defer s.wg.Done() + ticker := time.NewTicker(s.scheduleDuration) + defer ticker.Stop() + for { + select { + case <-s.ctx.Done(): + log.Ctx(s.ctx).Warn("task scheduler ctx done") + return + case _, ok := <-s.notifyChan: + if ok { + s.run() + } + // !ok means indexBuild is closed. 
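+ // Even without notifications, the ticker below triggers a periodic scheduling
+ // pass every scheduleDuration.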
+ case <-ticker.C: + s.run() + } + } +} + +func (s *taskScheduler) getTask(taskID UniqueID) Task { + s.RLock() + defer s.RUnlock() + + return s.tasks[taskID] +} + +func (s *taskScheduler) run() { + // schedule policy + s.RLock() + taskIDs := make([]UniqueID, 0, len(s.tasks)) + for tID := range s.tasks { + taskIDs = append(taskIDs, tID) + } + s.RUnlock() + if len(taskIDs) > 0 { + log.Ctx(s.ctx).Info("task scheduler", zap.Int("task num", len(taskIDs))) + } + + s.policy(taskIDs) + + for _, taskID := range taskIDs { + ok := s.process(taskID) + if !ok { + log.Ctx(s.ctx).Info("there is no idle indexing node, wait a minute...") + break + } + } +} + +func (s *taskScheduler) removeTask(taskID UniqueID) { + s.Lock() + defer s.Unlock() + delete(s.tasks, taskID) +} + +func (s *taskScheduler) process(taskID UniqueID) bool { + task := s.getTask(taskID) + + if !task.CheckTaskHealthy(s.meta) { + s.removeTask(taskID) + return true + } + state := task.GetState() + log.Ctx(s.ctx).Info("task is processing", zap.Int64("taskID", taskID), + zap.String("state", state.String())) + + switch state { + case indexpb.JobState_JobStateNone: + s.removeTask(taskID) + + case indexpb.JobState_JobStateInit: + // 1. pick an indexNode client + nodeID, client := s.nodeManager.PickClient() + if client == nil { + log.Ctx(s.ctx).Debug("pick client failed") + return false + } + log.Ctx(s.ctx).Info("pick client success", zap.Int64("taskID", taskID), zap.Int64("nodeID", nodeID)) + + // 2. update version + if err := task.UpdateVersion(s.ctx, s.meta); err != nil { + log.Ctx(s.ctx).Warn("update task version failed", zap.Int64("taskID", taskID), zap.Error(err)) + return false + } + log.Ctx(s.ctx).Info("update task version success", zap.Int64("taskID", taskID)) + + // 3. assign task to indexNode + success, skip := task.AssignTask(s.ctx, client, s) + if !success { + log.Ctx(s.ctx).Warn("assign task to client failed", zap.Int64("taskID", taskID), + zap.String("new state", task.GetState().String()), zap.String("fail reason", task.GetFailReason())) + // If the problem is caused by the task itself, subsequent tasks will not be skipped. + // If etcd fails or fails to send tasks to the node, the subsequent tasks will be skipped. + return !skip + } + if skip { + // create index for small segment(<1024), skip next steps. + return true + } + log.Ctx(s.ctx).Info("assign task to client success", zap.Int64("taskID", taskID), zap.Int64("nodeID", nodeID)) + + // 4. 
update meta state + if err := task.UpdateMetaBuildingState(nodeID, s.meta); err != nil { + log.Ctx(s.ctx).Warn("update meta building state failed", zap.Int64("taskID", taskID), zap.Error(err)) + task.SetState(indexpb.JobState_JobStateRetry, "update meta building state failed") + return false + } + log.Ctx(s.ctx).Info("update task meta state to InProgress success", zap.Int64("taskID", taskID), + zap.Int64("nodeID", nodeID)) + case indexpb.JobState_JobStateFinished, indexpb.JobState_JobStateFailed: + if err := task.SetJobInfo(s.meta); err != nil { + log.Ctx(s.ctx).Warn("update task info failed", zap.Error(err)) + return true + } + client, exist := s.nodeManager.GetClientByID(task.GetNodeID()) + if exist { + if !task.DropTaskOnWorker(s.ctx, client) { + return true + } + } + s.removeTask(taskID) + case indexpb.JobState_JobStateRetry: + client, exist := s.nodeManager.GetClientByID(task.GetNodeID()) + if exist { + if !task.DropTaskOnWorker(s.ctx, client) { + return true + } + } + task.SetState(indexpb.JobState_JobStateInit, "") + task.ResetNodeID() + + default: + // state: in_progress + client, exist := s.nodeManager.GetClientByID(task.GetNodeID()) + if exist { + task.QueryResult(s.ctx, client) + return true + } + task.SetState(indexpb.JobState_JobStateRetry, "") + } + return true +} diff --git a/internal/datacoord/task_scheduler_test.go b/internal/datacoord/task_scheduler_test.go new file mode 100644 index 0000000000..dd072e1a36 --- /dev/null +++ b/internal/datacoord/task_scheduler_test.go @@ -0,0 +1,1423 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
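+// The tests below exercise the scheduler against mocked catalog, worker manager and
+// index-node clients: one full happy-path run plus per-state failure cases for both
+// analyze and index-build tasks.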
+ +package datacoord + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/metastore" + catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks" + "github.com/milvus-io/milvus/internal/metastore/model" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +var ( + collID = UniqueID(100) + partID = UniqueID(200) + indexID = UniqueID(300) + fieldID = UniqueID(400) + indexName = "_default_idx" + segID = UniqueID(500) + buildID = UniqueID(600) + nodeID = UniqueID(700) +) + +func createIndexMeta(catalog metastore.DataCoordCatalog) *indexMeta { + return &indexMeta{ + catalog: catalog, + indexes: map[UniqueID]map[UniqueID]*model.Index{ + collID: { + indexID: { + TenantID: "", + CollectionID: collID, + FieldID: fieldID, + IndexID: indexID, + IndexName: indexName, + IsDeleted: false, + CreateTime: 1, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "128", + }, + }, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "HNSW", + }, + { + Key: common.MetricTypeKey, + Value: "L2", + }, + }, + }, + }, + }, + segmentIndexes: map[UniqueID]map[UniqueID]*model.SegmentIndex{ + segID: { + indexID: { + SegmentID: segID, + CollectionID: collID, + PartitionID: partID, + NumRows: 1025, + IndexID: indexID, + BuildID: buildID, + NodeID: 0, + IndexVersion: 0, + IndexState: commonpb.IndexState_Unissued, + FailReason: "", + IsDeleted: false, + CreateTime: 0, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 1: { + indexID: { + SegmentID: segID + 1, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 1, + NodeID: nodeID, + IndexVersion: 1, + IndexState: commonpb.IndexState_InProgress, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 2: { + indexID: { + SegmentID: segID + 2, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 2, + NodeID: nodeID, + IndexVersion: 1, + IndexState: commonpb.IndexState_InProgress, + FailReason: "", + IsDeleted: true, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 3: { + indexID: { + SegmentID: segID + 3, + CollectionID: collID, + PartitionID: partID, + NumRows: 500, + IndexID: indexID, + BuildID: buildID + 3, + NodeID: 0, + IndexVersion: 0, + IndexState: commonpb.IndexState_Unissued, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 4: { + indexID: { + SegmentID: segID + 4, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 4, + NodeID: nodeID, + IndexVersion: 1, + IndexState: commonpb.IndexState_Finished, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 5: { + indexID: { + SegmentID: segID + 5, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: 
buildID + 5, + NodeID: 0, + IndexVersion: 1, + IndexState: commonpb.IndexState_Finished, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 6: { + indexID: { + SegmentID: segID + 6, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 6, + NodeID: 0, + IndexVersion: 1, + IndexState: commonpb.IndexState_Finished, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 7: { + indexID: { + SegmentID: segID + 7, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 7, + NodeID: 0, + IndexVersion: 1, + IndexState: commonpb.IndexState_Failed, + FailReason: "error", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 8: { + indexID: { + SegmentID: segID + 8, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 8, + NodeID: nodeID + 1, + IndexVersion: 1, + IndexState: commonpb.IndexState_InProgress, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 9: { + indexID: { + SegmentID: segID + 9, + CollectionID: collID, + PartitionID: partID, + NumRows: 500, + IndexID: indexID, + BuildID: buildID + 9, + NodeID: 0, + IndexVersion: 0, + IndexState: commonpb.IndexState_Unissued, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + segID + 10: { + indexID: { + SegmentID: segID + 10, + CollectionID: collID, + PartitionID: partID, + NumRows: 500, + IndexID: indexID, + BuildID: buildID + 10, + NodeID: nodeID, + IndexVersion: 0, + IndexState: commonpb.IndexState_Unissued, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + }, + buildID2SegmentIndex: map[UniqueID]*model.SegmentIndex{ + buildID: { + SegmentID: segID, + CollectionID: collID, + PartitionID: partID, + NumRows: 1025, + IndexID: indexID, + BuildID: buildID, + NodeID: 0, + IndexVersion: 0, + IndexState: commonpb.IndexState_Unissued, + FailReason: "", + IsDeleted: false, + CreateTime: 0, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 1: { + SegmentID: segID + 1, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 1, + NodeID: nodeID, + IndexVersion: 1, + IndexState: commonpb.IndexState_InProgress, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 2: { + SegmentID: segID + 2, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 2, + NodeID: nodeID, + IndexVersion: 1, + IndexState: commonpb.IndexState_InProgress, + FailReason: "", + IsDeleted: true, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 3: { + SegmentID: segID + 3, + CollectionID: collID, + PartitionID: partID, + NumRows: 500, + IndexID: indexID, + BuildID: buildID + 3, + NodeID: 0, + IndexVersion: 0, + IndexState: commonpb.IndexState_Unissued, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 4: { + SegmentID: segID + 4, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 4, + NodeID: nodeID, + IndexVersion: 1, + IndexState: commonpb.IndexState_Finished, + FailReason: "", + IsDeleted: 
false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 5: { + SegmentID: segID + 5, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 5, + NodeID: 0, + IndexVersion: 1, + IndexState: commonpb.IndexState_Finished, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 6: { + SegmentID: segID + 6, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 6, + NodeID: 0, + IndexVersion: 1, + IndexState: commonpb.IndexState_Finished, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 7: { + SegmentID: segID + 7, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 7, + NodeID: 0, + IndexVersion: 1, + IndexState: commonpb.IndexState_Failed, + FailReason: "error", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 8: { + SegmentID: segID + 8, + CollectionID: collID, + PartitionID: partID, + NumRows: 1026, + IndexID: indexID, + BuildID: buildID + 8, + NodeID: nodeID + 1, + IndexVersion: 1, + IndexState: commonpb.IndexState_InProgress, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 9: { + SegmentID: segID + 9, + CollectionID: collID, + PartitionID: partID, + NumRows: 500, + IndexID: indexID, + BuildID: buildID + 9, + NodeID: 0, + IndexVersion: 0, + IndexState: commonpb.IndexState_Unissued, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + buildID + 10: { + SegmentID: segID + 10, + CollectionID: collID, + PartitionID: partID, + NumRows: 500, + IndexID: indexID, + BuildID: buildID + 10, + NodeID: nodeID, + IndexVersion: 0, + IndexState: commonpb.IndexState_Unissued, + FailReason: "", + IsDeleted: false, + CreateTime: 1111, + IndexFileKeys: nil, + IndexSize: 1, + }, + }, + } +} + +func createMeta(catalog metastore.DataCoordCatalog, am *analyzeMeta, im *indexMeta) *meta { + return &meta{ + catalog: catalog, + segments: &SegmentsInfo{ + segments: map[UniqueID]*SegmentInfo{ + 1000: { + SegmentInfo: &datapb.SegmentInfo{ + ID: 1000, + CollectionID: 10000, + PartitionID: 10001, + NumOfRows: 3000, + State: commonpb.SegmentState_Flushed, + Binlogs: []*datapb.FieldBinlog{{FieldID: 10002, Binlogs: []*datapb.Binlog{{LogID: 1}, {LogID: 2}, {LogID: 3}}}}, + }, + }, + 1001: { + SegmentInfo: &datapb.SegmentInfo{ + ID: 1001, + CollectionID: 10000, + PartitionID: 10001, + NumOfRows: 3000, + State: commonpb.SegmentState_Flushed, + Binlogs: []*datapb.FieldBinlog{{FieldID: 10002, Binlogs: []*datapb.Binlog{{LogID: 1}, {LogID: 2}, {LogID: 3}}}}, + }, + }, + 1002: { + SegmentInfo: &datapb.SegmentInfo{ + ID: 1002, + CollectionID: 10000, + PartitionID: 10001, + NumOfRows: 3000, + State: commonpb.SegmentState_Flushed, + Binlogs: []*datapb.FieldBinlog{{FieldID: 10002, Binlogs: []*datapb.Binlog{{LogID: 1}, {LogID: 2}, {LogID: 3}}}}, + }, + }, + segID: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 1025, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 1: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 1, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 1026, + State: commonpb.SegmentState_Flushed, + 
MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 2: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 2, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 1026, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 3: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 3, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 500, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 4: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 4, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 1026, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 5: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 5, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 1026, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 6: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 6, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 1026, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 7: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 7, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 1026, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 8: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 8, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 1026, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 9: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 9, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 500, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + segID + 10: { + SegmentInfo: &datapb.SegmentInfo{ + ID: segID + 10, + CollectionID: collID, + PartitionID: partID, + InsertChannel: "", + NumOfRows: 500, + State: commonpb.SegmentState_Flushed, + MaxRowNum: 65536, + LastExpireTime: 10, + }, + }, + }, + }, + analyzeMeta: am, + indexMeta: im, + } +} + +type taskSchedulerSuite struct { + suite.Suite + + collectionID int64 + partitionID int64 + fieldID int64 + segmentIDs []int64 + nodeID int64 + duration time.Duration +} + +func (s *taskSchedulerSuite) initParams() { + s.collectionID = collID + s.partitionID = partID + s.fieldID = fieldID + s.nodeID = nodeID + s.segmentIDs = []int64{1000, 1001, 1002} + s.duration = time.Millisecond * 100 +} + +func (s *taskSchedulerSuite) createAnalyzeMeta(catalog metastore.DataCoordCatalog) *analyzeMeta { + return &analyzeMeta{ + ctx: context.Background(), + catalog: catalog, + tasks: map[int64]*indexpb.AnalyzeTask{ + 1: { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 1, + State: indexpb.JobState_JobStateInit, + FieldType: schemapb.DataType_FloatVector, + }, + 2: { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 2, + NodeID: s.nodeID, + State: indexpb.JobState_JobStateInProgress, + FieldType: schemapb.DataType_FloatVector, + }, + 3: { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 3, + NodeID: 
s.nodeID, + State: indexpb.JobState_JobStateFinished, + FieldType: schemapb.DataType_FloatVector, + }, + 4: { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 4, + NodeID: s.nodeID, + State: indexpb.JobState_JobStateFailed, + FieldType: schemapb.DataType_FloatVector, + }, + 5: { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: []int64{1001, 1002}, + TaskID: 5, + NodeID: s.nodeID, + State: indexpb.JobState_JobStateRetry, + FieldType: schemapb.DataType_FloatVector, + }, + }, + } +} + +func (s *taskSchedulerSuite) SetupTest() { + paramtable.Init() + s.initParams() + Params.DataCoordCfg.ClusteringCompactionMinCentroidsNum.SwapTempValue("0") +} + +func (s *taskSchedulerSuite) TearDownSuite() { + Params.DataCoordCfg.ClusteringCompactionMinCentroidsNum.SwapTempValue("16") +} + +func (s *taskSchedulerSuite) scheduler(handler Handler) { + ctx := context.Background() + + catalog := catalogmocks.NewDataCoordCatalog(s.T()) + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil) + catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil) + + in := mocks.NewMockIndexNodeClient(s.T()) + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil) + in.EXPECT().QueryJobsV2(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, request *indexpb.QueryJobsV2Request, option ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) { + switch request.GetJobType() { + case indexpb.JobType_JobTypeIndexJob: + results := make([]*indexpb.IndexTaskInfo, 0) + for _, buildID := range request.GetTaskIDs() { + results = append(results, &indexpb.IndexTaskInfo{ + BuildID: buildID, + State: commonpb.IndexState_Finished, + IndexFileKeys: []string{"file1", "file2", "file3"}, + SerializedSize: 1024, + FailReason: "", + CurrentIndexVersion: 1, + IndexStoreVersion: 1, + }) + } + return &indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: request.GetClusterID(), + Result: &indexpb.QueryJobsV2Response_IndexJobResults{ + IndexJobResults: &indexpb.IndexJobResults{ + Results: results, + }, + }, + }, nil + case indexpb.JobType_JobTypeAnalyzeJob: + results := make([]*indexpb.AnalyzeResult, 0) + for _, taskID := range request.GetTaskIDs() { + results = append(results, &indexpb.AnalyzeResult{ + TaskID: taskID, + State: indexpb.JobState_JobStateFinished, + CentroidsFile: fmt.Sprintf("%d/stats_file", taskID), + FailReason: "", + }) + } + return &indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: request.GetClusterID(), + Result: &indexpb.QueryJobsV2Response_AnalyzeJobResults{ + AnalyzeJobResults: &indexpb.AnalyzeResults{ + Results: results, + }, + }, + }, nil + default: + return &indexpb.QueryJobsV2Response{ + Status: merr.Status(errors.New("unknown job type")), + ClusterID: request.GetClusterID(), + }, nil + } + }) + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil) + + workerManager := NewMockWorkerManager(s.T()) + workerManager.EXPECT().PickClient().Return(s.nodeID, in) + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true) + + mt := createMeta(catalog, s.createAnalyzeMeta(catalog), createIndexMeta(catalog)) + + cm := mocks.NewChunkManager(s.T()) + cm.EXPECT().RootPath().Return("root") + + scheduler := newTaskScheduler(ctx, mt, workerManager, cm, newIndexEngineVersionManager(), handler) + s.Equal(9, len(scheduler.tasks)) + 
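+ // reloadFromKV picks up the 3 unfinished analyze tasks (1, 2, 5) and the 6
+ // unfinished index builds, hence 9 tasks in the queue.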
s.Equal(indexpb.JobState_JobStateInit, scheduler.tasks[1].GetState()) + s.Equal(indexpb.JobState_JobStateInProgress, scheduler.tasks[2].GetState()) + s.Equal(indexpb.JobState_JobStateRetry, scheduler.tasks[5].GetState()) + s.Equal(indexpb.JobState_JobStateInit, scheduler.tasks[buildID].GetState()) + s.Equal(indexpb.JobState_JobStateInProgress, scheduler.tasks[buildID+1].GetState()) + s.Equal(indexpb.JobState_JobStateInit, scheduler.tasks[buildID+3].GetState()) + s.Equal(indexpb.JobState_JobStateInProgress, scheduler.tasks[buildID+8].GetState()) + s.Equal(indexpb.JobState_JobStateInit, scheduler.tasks[buildID+9].GetState()) + s.Equal(indexpb.JobState_JobStateInit, scheduler.tasks[buildID+10].GetState()) + + mt.segments.DropSegment(segID + 9) + + scheduler.scheduleDuration = time.Millisecond * 500 + scheduler.Start() + + s.Run("enqueue", func() { + taskID := int64(6) + newTask := &indexpb.AnalyzeTask{ + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: taskID, + } + err := scheduler.meta.analyzeMeta.AddAnalyzeTask(newTask) + s.NoError(err) + t := &analyzeTask{ + taskID: taskID, + taskInfo: &indexpb.AnalyzeResult{ + TaskID: taskID, + State: indexpb.JobState_JobStateInit, + FailReason: "", + }, + } + scheduler.enqueue(t) + }) + + for { + scheduler.RLock() + taskNum := len(scheduler.tasks) + scheduler.RUnlock() + + if taskNum == 0 { + break + } + time.Sleep(time.Second) + } + + scheduler.Stop() + + s.Equal(indexpb.JobState_JobStateFinished, mt.analyzeMeta.GetTask(1).GetState()) + s.Equal(indexpb.JobState_JobStateFinished, mt.analyzeMeta.GetTask(2).GetState()) + s.Equal(indexpb.JobState_JobStateFinished, mt.analyzeMeta.GetTask(3).GetState()) + s.Equal(indexpb.JobState_JobStateFailed, mt.analyzeMeta.GetTask(4).GetState()) + s.Equal(indexpb.JobState_JobStateFinished, mt.analyzeMeta.GetTask(5).GetState()) + s.Equal(indexpb.JobState_JobStateFinished, mt.analyzeMeta.GetTask(6).GetState()) + indexJob, exist := mt.indexMeta.GetIndexJob(buildID) + s.True(exist) + s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 1) + s.True(exist) + s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 2) + s.True(exist) + s.True(indexJob.IsDeleted) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 3) + s.True(exist) + s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 4) + s.True(exist) + s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 5) + s.True(exist) + s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 6) + s.True(exist) + s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 7) + s.True(exist) + s.Equal(commonpb.IndexState_Failed, indexJob.IndexState) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 8) + s.True(exist) + s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 9) + s.True(exist) + // segment not healthy, wait for GC + s.Equal(commonpb.IndexState_Unissued, indexJob.IndexState) + indexJob, exist = mt.indexMeta.GetIndexJob(buildID + 10) + s.True(exist) + s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) +} + +func (s *taskSchedulerSuite) Test_scheduler() { + handler := 
NewNMockHandler(s.T()) + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", IsPrimaryKey: true, IsPartitionKey: true, DataType: schemapb.DataType_Int64}, + {FieldID: s.fieldID, Name: "vec", TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "10"}}}, + }, + }, + }, nil) + + s.Run("test scheduler with indexBuilderV1", func() { + paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("True") + defer paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("False") + s.scheduler(handler) + }) + + s.Run("test scheduler with indexBuilderV2", func() { + paramtable.Get().CommonCfg.EnableStorageV2.SwapTempValue("true") + defer paramtable.Get().CommonCfg.EnableStorageV2.SwapTempValue("false") + + s.scheduler(handler) + }) +} + +func (s *taskSchedulerSuite) Test_analyzeTaskFailCase() { + s.Run("segment info is nil", func() { + ctx := context.Background() + + catalog := catalogmocks.NewDataCoordCatalog(s.T()) + in := mocks.NewMockIndexNodeClient(s.T()) + workerManager := NewMockWorkerManager(s.T()) + + mt := createMeta(catalog, + &analyzeMeta{ + ctx: context.Background(), + catalog: catalog, + tasks: map[int64]*indexpb.AnalyzeTask{ + 1: { + CollectionID: s.collectionID, + PartitionID: s.partitionID, + FieldID: s.fieldID, + SegmentIDs: s.segmentIDs, + TaskID: 1, + State: indexpb.JobState_JobStateInit, + }, + }, + }, + &indexMeta{ + RWMutex: sync.RWMutex{}, + ctx: ctx, + catalog: catalog, + }) + + handler := NewNMockHandler(s.T()) + scheduler := newTaskScheduler(ctx, mt, workerManager, nil, nil, handler) + + mt.segments.DropSegment(1000) + scheduler.scheduleDuration = s.duration + scheduler.Start() + + // taskID 1 peek client success, update version success. 
AssignTask failed --> state: Failed --> save + workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once() + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Once() + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Once() + workerManager.EXPECT().GetClientByID(mock.Anything).Return(nil, false).Once() + + for { + scheduler.RLock() + taskNum := len(scheduler.tasks) + scheduler.RUnlock() + + if taskNum == 0 { + break + } + time.Sleep(time.Second) + } + + scheduler.Stop() + s.Equal(indexpb.JobState_JobStateFailed, mt.analyzeMeta.GetTask(1).GetState()) + }) + + s.Run("etcd save failed", func() { + ctx := context.Background() + + catalog := catalogmocks.NewDataCoordCatalog(s.T()) + catalog.EXPECT().DropAnalyzeTask(mock.Anything, mock.Anything).Return(nil) + + in := mocks.NewMockIndexNodeClient(s.T()) + + workerManager := NewMockWorkerManager(s.T()) + + mt := createMeta(catalog, s.createAnalyzeMeta(catalog), &indexMeta{ + RWMutex: sync.RWMutex{}, + ctx: ctx, + catalog: catalog, + }) + + handler := NewNMockHandler(s.T()) + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: s.fieldID, + Name: "vec", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "10"}, + }, + }, + }, + }, + }, nil) + + scheduler := newTaskScheduler(ctx, mt, workerManager, nil, nil, handler) + + // remove task in meta + err := scheduler.meta.analyzeMeta.DropAnalyzeTask(1) + s.NoError(err) + err = scheduler.meta.analyzeMeta.DropAnalyzeTask(2) + s.NoError(err) + + mt.segments.DropSegment(1000) + scheduler.scheduleDuration = s.duration + scheduler.Start() + + // taskID 5 state retry, drop task on worker --> state: Init + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // pick client fail --> state: init + workerManager.EXPECT().PickClient().Return(0, nil).Once() + + // update version failed --> state: init + workerManager.EXPECT().PickClient().Return(s.nodeID, in) + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(errors.New("catalog update version error")).Once() + + // assign task to indexNode fail --> state: retry + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Once() + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(&commonpb.Status{ + Code: 65535, + Retriable: false, + Detail: "", + ExtraInfo: nil, + Reason: "mock error", + }, nil).Once() + + // drop task failed --> state: retry + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Status(errors.New("drop job failed")), nil).Once() + + // retry --> state: init + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // update state to building failed --> state: retry + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Once() + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(errors.New("catalog update building state error")).Once() + + // retry --> state: init + 
workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // assign success --> state: InProgress + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Twice() + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // query result InProgress --> state: InProgress + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().QueryJobsV2(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, request *indexpb.QueryJobsV2Request, option ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) { + results := make([]*indexpb.AnalyzeResult, 0) + for _, taskID := range request.GetTaskIDs() { + results = append(results, &indexpb.AnalyzeResult{ + TaskID: taskID, + State: indexpb.JobState_JobStateInProgress, + }) + } + return &indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: request.GetClusterID(), + Result: &indexpb.QueryJobsV2Response_AnalyzeJobResults{ + AnalyzeJobResults: &indexpb.AnalyzeResults{ + Results: results, + }, + }, + }, nil + }).Once() + + // query result Retry --> state: retry + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().QueryJobsV2(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, request *indexpb.QueryJobsV2Request, option ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) { + results := make([]*indexpb.AnalyzeResult, 0) + for _, taskID := range request.GetTaskIDs() { + results = append(results, &indexpb.AnalyzeResult{ + TaskID: taskID, + State: indexpb.JobState_JobStateRetry, + FailReason: "node analyze data failed", + }) + } + return &indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: request.GetClusterID(), + Result: &indexpb.QueryJobsV2Response_AnalyzeJobResults{ + AnalyzeJobResults: &indexpb.AnalyzeResults{ + Results: results, + }, + }, + }, nil + }).Once() + + // retry --> state: init + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // init --> state: InProgress + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Twice() + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // query result failed --> state: retry + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().QueryJobsV2(mock.Anything, mock.Anything).Return(&indexpb.QueryJobsV2Response{ + Status: merr.Status(errors.New("query job failed")), + }, nil).Once() + + // retry --> state: init + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // init --> state: InProgress + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Twice() + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // query result not exists --> state: retry + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().QueryJobsV2(mock.Anything, mock.Anything).Return(&indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: "", + Result: &indexpb.QueryJobsV2Response_AnalyzeJobResults{}, + }, nil).Once() + + // retry --> state: init + 
workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // init --> state: InProgress + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Twice() + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // node not exist --> state: retry + workerManager.EXPECT().GetClientByID(mock.Anything).Return(nil, false).Once() + + // retry --> state: init + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // init --> state: InProgress + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Twice() + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + // query result success --> state: finished + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().QueryJobsV2(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, request *indexpb.QueryJobsV2Request, option ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) { + results := make([]*indexpb.AnalyzeResult, 0) + for _, taskID := range request.GetTaskIDs() { + results = append(results, &indexpb.AnalyzeResult{ + TaskID: taskID, + State: indexpb.JobState_JobStateFinished, + //CentroidsFile: fmt.Sprintf("%d/stats_file", taskID), + //SegmentOffsetMappingFiles: map[int64]string{ + // 1000: "1000/offset_mapping", + // 1001: "1001/offset_mapping", + // 1002: "1002/offset_mapping", + //}, + FailReason: "", + }) + } + return &indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: request.GetClusterID(), + Result: &indexpb.QueryJobsV2Response_AnalyzeJobResults{ + AnalyzeJobResults: &indexpb.AnalyzeResults{ + Results: results, + }, + }, + }, nil + }).Once() + // set job info failed --> state: Finished + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(errors.New("set job info failed")).Once() + + // set job success, drop job on task failed --> state: Finished + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Once() + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Status(errors.New("drop job failed")), nil).Once() + + // drop job success --> no task + catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil).Once() + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once() + + for { + scheduler.RLock() + taskNum := len(scheduler.tasks) + scheduler.RUnlock() + + if taskNum == 0 { + break + } + time.Sleep(time.Second) + } + + scheduler.Stop() + }) +} + +func (s *taskSchedulerSuite) Test_indexTaskFailCase() { + s.Run("HNSW", func() { + ctx := context.Background() + + catalog := catalogmocks.NewDataCoordCatalog(s.T()) + in := mocks.NewMockIndexNodeClient(s.T()) + workerManager := NewMockWorkerManager(s.T()) + + mt := createMeta(catalog, + &analyzeMeta{ + ctx: context.Background(), + catalog: catalog, + }, + &indexMeta{ + RWMutex: sync.RWMutex{}, + ctx: ctx, + catalog: catalog, + indexes: map[UniqueID]map[UniqueID]*model.Index{ + s.collectionID: { + indexID: { + CollectionID: s.collectionID, + FieldID: s.fieldID, + IndexID: indexID, + IndexName: indexName, + TypeParams: 
[]*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "128", + }, + }, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "HNSW", + }, + { + Key: common.MetricTypeKey, + Value: "L2", + }, + }, + }, + }, + }, + buildID2SegmentIndex: map[UniqueID]*model.SegmentIndex{ + buildID: { + SegmentID: segID, + CollectionID: s.collectionID, + PartitionID: s.partitionID, + NumRows: 1025, + IndexID: indexID, + BuildID: buildID, + IndexState: commonpb.IndexState_Unissued, + }, + }, + segmentIndexes: map[UniqueID]map[UniqueID]*model.SegmentIndex{ + segID: { + buildID: { + SegmentID: segID, + CollectionID: s.collectionID, + PartitionID: s.partitionID, + NumRows: 1025, + IndexID: indexID, + BuildID: buildID, + IndexState: commonpb.IndexState_Unissued, + }, + }, + }, + }) + + cm := mocks.NewChunkManager(s.T()) + cm.EXPECT().RootPath().Return("ut-index") + + handler := NewNMockHandler(s.T()) + scheduler := newTaskScheduler(ctx, mt, workerManager, cm, newIndexEngineVersionManager(), handler) + + paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("True") + defer paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("False") + err := Params.Save("common.storage.scheme", "fake") + defer Params.Reset("common.storage.scheme") + Params.CommonCfg.EnableStorageV2.SwapTempValue("True") + defer Params.CommonCfg.EnableStorageV2.SwapTempValue("False") + scheduler.Start() + + // peek client success, update version success, get collection info failed --> init + workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once() + catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once() + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once() + + // peek client success, update version success, partition key field is nil, get collection info failed --> init + workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once() + catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once() + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: s.fieldID, Name: "vec", TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "10"}}}, + }, + }, + }, nil).Once() + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once() + + // peek client success, update version success, get collection info success, get dim failed --> init + workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once() + catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once() + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", IsPrimaryKey: true, IsPartitionKey: true, DataType: schemapb.DataType_Int64}, + {FieldID: s.fieldID, Name: "vec"}, + }, + }, + }, nil).Twice() + + // peek client success, update version success, get collection info success, get dim success, get storage uri failed --> init + s.NoError(err) + workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once() + catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once() + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, i int64) (*collectionInfo, error) { + return &collectionInfo{ + ID: collID, + Schema: 
&schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", IsPrimaryKey: true, IsPartitionKey: true, DataType: schemapb.DataType_Int64}, + {FieldID: s.fieldID, Name: "vec", TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "10"}}}, + }, + }, + }, nil + }).Twice() + s.NoError(err) + + // assign failed --> retry + workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once() + catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once() + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, i int64) (*collectionInfo, error) { + Params.Reset("common.storage.scheme") + return &collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", IsPrimaryKey: true, IsPartitionKey: true, DataType: schemapb.DataType_Int64}, + {FieldID: s.fieldID, Name: "vec", TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "10"}}}, + }, + }, + }, nil + }).Twice() + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once() + + // retry --> init + workerManager.EXPECT().GetClientByID(mock.Anything).Return(nil, false).Once() + + // init --> inProgress + workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once() + catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Twice() + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", IsPrimaryKey: true, IsPartitionKey: true, DataType: schemapb.DataType_Int64}, + {FieldID: s.fieldID, Name: "vec", TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "10"}}}, + }, + }, + }, nil).Twice() + in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil).Once() + + // inProgress --> Finished + workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() + in.EXPECT().QueryJobsV2(mock.Anything, mock.Anything).Return(&indexpb.QueryJobsV2Response{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + ClusterID: "", + Result: &indexpb.QueryJobsV2Response_IndexJobResults{ + IndexJobResults: &indexpb.IndexJobResults{ + Results: []*indexpb.IndexTaskInfo{ + { + BuildID: buildID, + State: commonpb.IndexState_Finished, + IndexFileKeys: []string{"file1", "file2"}, + SerializedSize: 1024, + }, + }, + }, + }, + }, nil) + + // finished --> done + catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once() + workerManager.EXPECT().GetClientByID(mock.Anything).Return(nil, false).Once() + + for { + scheduler.RLock() + taskNum := len(scheduler.tasks) + scheduler.RUnlock() + + if taskNum == 0 { + break + } + time.Sleep(time.Second) + } + + scheduler.Stop() + + indexJob, exist := mt.indexMeta.GetIndexJob(buildID) + s.True(exist) + s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) + }) +} + +func Test_taskSchedulerSuite(t *testing.T) { + suite.Run(t, new(taskSchedulerSuite)) +} diff --git a/internal/datacoord/types.go b/internal/datacoord/types.go new file mode 100644 index 0000000000..c6fe60365c --- /dev/null +++ b/internal/datacoord/types.go @@ -0,0 +1,40 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package datacoord
+
+import (
+	"context"
+
+	"github.com/milvus-io/milvus/internal/proto/indexpb"
+	"github.com/milvus-io/milvus/internal/types"
+)
+
+type Task interface {
+	GetTaskID() int64
+	GetNodeID() int64
+	ResetNodeID()
+	CheckTaskHealthy(mt *meta) bool
+	SetState(state indexpb.JobState, failReason string)
+	GetState() indexpb.JobState
+	GetFailReason() string
+	UpdateVersion(ctx context.Context, meta *meta) error
+	UpdateMetaBuildingState(nodeID int64, meta *meta) error
+	AssignTask(ctx context.Context, client types.IndexNodeClient, dependency *taskScheduler) (bool, bool)
+	QueryResult(ctx context.Context, client types.IndexNodeClient)
+	DropTaskOnWorker(ctx context.Context, client types.IndexNodeClient) bool
+	SetJobInfo(meta *meta) error
+}
diff --git a/internal/datacoord/util.go b/internal/datacoord/util.go
index 158b5e68c5..2a5c6c6279 100644
--- a/internal/datacoord/util.go
+++ b/internal/datacoord/util.go
@@ -196,6 +196,10 @@ func isFlatIndex(indexType string) bool {
 	return indexType == indexparamcheck.IndexFaissIDMap || indexType == indexparamcheck.IndexFaissBinIDMap
 }
 
+func isOptionalScalarFieldSupported(indexType string) bool {
+	return indexType == indexparamcheck.IndexHNSW
+}
+
 func isDiskANNIndex(indexType string) bool {
 	return indexType == indexparamcheck.IndexDISKANN
 }
@@ -256,3 +260,16 @@ func getCompactionMergeInfo(task *datapb.CompactionTask) *milvuspb.CompactionMer
 		Target: target,
 	}
 }
+
+func getBinLogIDs(segment *SegmentInfo, fieldID int64) []int64 {
+	binlogIDs := make([]int64, 0)
+	for _, fieldBinLog := range segment.GetBinlogs() {
+		if fieldBinLog.GetFieldID() == fieldID {
+			for _, binLog := range fieldBinLog.GetBinlogs() {
+				binlogIDs = append(binlogIDs, binLog.GetLogID())
+			}
+			break
+		}
+	}
+	return binlogIDs
+}
diff --git a/internal/distributed/indexnode/client/client.go b/internal/distributed/indexnode/client/client.go
index 26505dd96f..df44f9ee59 100644
--- a/internal/distributed/indexnode/client/client.go
+++ b/internal/distributed/indexnode/client/client.go
@@ -163,3 +163,21 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
 		return client.GetMetrics(ctx, req)
 	})
 }
+
+func (c *Client) CreateJobV2(ctx context.Context, req *indexpb.CreateJobV2Request, opts ...grpc.CallOption) (*commonpb.Status, error) {
+	return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*commonpb.Status, error) {
+		return client.CreateJobV2(ctx, req)
+	})
+}
+
+func (c *Client) QueryJobsV2(ctx context.Context, req *indexpb.QueryJobsV2Request, opts ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) {
+	return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*indexpb.QueryJobsV2Response, error) {
+		return client.QueryJobsV2(ctx, req)
+	})
+}
+
+func (c *Client) DropJobsV2(ctx context.Context, req *indexpb.DropJobsV2Request, opt ...grpc.CallOption) (*commonpb.Status, error) {
+	return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*commonpb.Status, error) {
+		return client.DropJobsV2(ctx, req)
+	})
+}
diff --git a/internal/distributed/indexnode/client/client_test.go b/internal/distributed/indexnode/client/client_test.go
index b88f7a9664..7b8227d052 100644
--- a/internal/distributed/indexnode/client/client_test.go
+++ b/internal/distributed/indexnode/client/client_test.go
@@ -164,6 +164,24 @@ func TestIndexNodeClient(t *testing.T) {
 		assert.NoError(t, err)
 	})
 
+	t.Run("CreateJobV2", func(t *testing.T) {
+		req := &indexpb.CreateJobV2Request{}
+		_, err := inc.CreateJobV2(ctx, req)
+		assert.NoError(t, err)
+	})
+
+	t.Run("QueryJobsV2", func(t *testing.T) {
+		req := &indexpb.QueryJobsV2Request{}
+		_, err := inc.QueryJobsV2(ctx, req)
+		assert.NoError(t, err)
+	})
+
+	t.Run("DropJobsV2", func(t *testing.T) {
+		req := &indexpb.DropJobsV2Request{}
+		_, err := inc.DropJobsV2(ctx, req)
+		assert.NoError(t, err)
+	})
+
 	err := inc.Close()
 	assert.NoError(t, err)
 }
diff --git a/internal/distributed/indexnode/service.go b/internal/distributed/indexnode/service.go
index c315d52cfc..a8a9909be7 100644
--- a/internal/distributed/indexnode/service.go
+++ b/internal/distributed/indexnode/service.go
@@ -289,6 +289,18 @@ func (s *Server) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsReq
 	return s.indexnode.GetMetrics(ctx, request)
 }
 
+func (s *Server) CreateJobV2(ctx context.Context, request *indexpb.CreateJobV2Request) (*commonpb.Status, error) {
+	return s.indexnode.CreateJobV2(ctx, request)
+}
+
+func (s *Server) QueryJobsV2(ctx context.Context, request *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) {
+	return s.indexnode.QueryJobsV2(ctx, request)
+}
+
+func (s *Server) DropJobsV2(ctx context.Context, request *indexpb.DropJobsV2Request) (*commonpb.Status, error) {
+	return s.indexnode.DropJobsV2(ctx, request)
+}
+
 // NewServer create a new IndexNode grpc server.
func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) { ctx1, cancel := context.WithCancel(ctx) diff --git a/internal/distributed/indexnode/service_test.go b/internal/distributed/indexnode/service_test.go index edfc175423..12b9af0b62 100644 --- a/internal/distributed/indexnode/service_test.go +++ b/internal/distributed/indexnode/service_test.go @@ -21,13 +21,15 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/indexnode" + "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -40,7 +42,13 @@ func TestIndexNodeServer(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, server) - inm := indexnode.NewIndexNodeMock() + inm := mocks.NewMockIndexNode(t) + inm.EXPECT().SetEtcdClient(mock.Anything).Return() + inm.EXPECT().SetAddress(mock.Anything).Return() + inm.EXPECT().Start().Return(nil) + inm.EXPECT().Init().Return(nil) + inm.EXPECT().Register().Return(nil) + inm.EXPECT().Stop().Return(nil) err = server.setServer(inm) assert.NoError(t, err) @@ -48,6 +56,11 @@ func TestIndexNodeServer(t *testing.T) { assert.NoError(t, err) t.Run("GetComponentStates", func(t *testing.T) { + inm.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ + State: &milvuspb.ComponentInfo{ + StateCode: commonpb.StateCode_Healthy, + }, + }, nil) req := &milvuspb.GetComponentStatesRequest{} states, err := server.GetComponentStates(ctx, req) assert.NoError(t, err) @@ -55,6 +68,9 @@ func TestIndexNodeServer(t *testing.T) { }) t.Run("GetStatisticsChannel", func(t *testing.T) { + inm.EXPECT().GetStatisticsChannel(mock.Anything, mock.Anything).Return(&milvuspb.StringResponse{ + Status: merr.Success(), + }, nil) req := &internalpb.GetStatisticsChannelRequest{} resp, err := server.GetStatisticsChannel(ctx, req) assert.NoError(t, err) @@ -62,6 +78,7 @@ func TestIndexNodeServer(t *testing.T) { }) t.Run("CreateJob", func(t *testing.T) { + inm.EXPECT().CreateJob(mock.Anything, mock.Anything).Return(merr.Success(), nil) req := &indexpb.CreateJobRequest{ ClusterID: "", BuildID: 0, @@ -74,6 +91,9 @@ func TestIndexNodeServer(t *testing.T) { }) t.Run("QueryJob", func(t *testing.T) { + inm.EXPECT().QueryJobs(mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{ + Status: merr.Success(), + }, nil) req := &indexpb.QueryJobsRequest{} resp, err := server.QueryJobs(ctx, req) assert.NoError(t, err) @@ -81,6 +101,7 @@ func TestIndexNodeServer(t *testing.T) { }) t.Run("DropJobs", func(t *testing.T) { + inm.EXPECT().DropJobs(mock.Anything, mock.Anything).Return(merr.Success(), nil) req := &indexpb.DropJobsRequest{} resp, err := server.DropJobs(ctx, req) assert.NoError(t, err) @@ -88,6 +109,9 @@ func TestIndexNodeServer(t *testing.T) { }) t.Run("ShowConfigurations", func(t *testing.T) { + inm.EXPECT().ShowConfigurations(mock.Anything, mock.Anything).Return(&internalpb.ShowConfigurationsResponse{ + Status: merr.Success(), + }, nil) req := &internalpb.ShowConfigurationsRequest{ Pattern: "", } @@ -97,6 +121,9 @@ func TestIndexNodeServer(t *testing.T) { }) t.Run("GetMetrics", func(t 
*testing.T) { + inm.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ + Status: merr.Success(), + }, nil) req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics) assert.NoError(t, err) resp, err := server.GetMetrics(ctx, req) @@ -105,12 +132,41 @@ func TestIndexNodeServer(t *testing.T) { }) t.Run("GetTaskSlots", func(t *testing.T) { + inm.EXPECT().GetJobStats(mock.Anything, mock.Anything).Return(&indexpb.GetJobStatsResponse{ + Status: merr.Success(), + }, nil) req := &indexpb.GetJobStatsRequest{} resp, err := server.GetJobStats(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) + t.Run("CreateJobV2", func(t *testing.T) { + inm.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil) + req := &indexpb.CreateJobV2Request{} + resp, err := server.CreateJobV2(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + + t.Run("QueryJobsV2", func(t *testing.T) { + inm.EXPECT().QueryJobsV2(mock.Anything, mock.Anything).Return(&indexpb.QueryJobsV2Response{ + Status: merr.Success(), + }, nil) + req := &indexpb.QueryJobsV2Request{} + resp, err := server.QueryJobsV2(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + t.Run("DropJobsV2", func(t *testing.T) { + inm.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil) + req := &indexpb.DropJobsV2Request{} + resp, err := server.DropJobsV2(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + err = server.Stop() assert.NoError(t, err) } diff --git a/internal/indexnode/indexnode.go b/internal/indexnode/indexnode.go index 2dc7ab7580..d904a4a296 100644 --- a/internal/indexnode/indexnode.go +++ b/internal/indexnode/indexnode.go @@ -17,7 +17,7 @@ package indexnode /* -#cgo pkg-config: milvus_common milvus_indexbuilder milvus_segcore +#cgo pkg-config: milvus_common milvus_indexbuilder milvus_clustering milvus_segcore #include #include @@ -105,9 +105,10 @@ type IndexNode struct { etcdCli *clientv3.Client address string - initOnce sync.Once - stateLock sync.Mutex - tasks map[taskKey]*taskInfo + initOnce sync.Once + stateLock sync.Mutex + indexTasks map[taskKey]*indexTaskInfo + analyzeTasks map[taskKey]*analyzeTaskInfo } // NewIndexNode creates a new IndexNode component. 
@@ -120,7 +121,8 @@ func NewIndexNode(ctx context.Context, factory dependency.Factory) *IndexNode { loopCancel: cancel, factory: factory, storageFactory: NewChunkMgrFactory(), - tasks: map[taskKey]*taskInfo{}, + indexTasks: make(map[taskKey]*indexTaskInfo), + analyzeTasks: make(map[taskKey]*analyzeTaskInfo), lifetime: lifetime.NewLifetime(commonpb.StateCode_Abnormal), } sc := NewTaskScheduler(b.loopCtx) @@ -251,10 +253,16 @@ func (i *IndexNode) Stop() error { i.lifetime.Wait() log.Info("Index node abnormal") // cleanup all running tasks - deletedTasks := i.deleteAllTasks() - for _, task := range deletedTasks { - if task.cancel != nil { - task.cancel() + deletedIndexTasks := i.deleteAllIndexTasks() + for _, t := range deletedIndexTasks { + if t.cancel != nil { + t.cancel() + } + } + deletedAnalyzeTasks := i.deleteAllAnalyzeTasks() + for _, t := range deletedAnalyzeTasks { + if t.cancel != nil { + t.cancel() } } if i.sched != nil { diff --git a/internal/indexnode/indexnode_mock.go b/internal/indexnode/indexnode_mock.go index fc1b9249cc..738d3386e2 100644 --- a/internal/indexnode/indexnode_mock.go +++ b/internal/indexnode/indexnode_mock.go @@ -18,7 +18,9 @@ package indexnode import ( "context" + "fmt" + "github.com/cockroachdb/errors" clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -52,6 +54,9 @@ type Mock struct { CallQueryJobs func(ctx context.Context, in *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) CallDropJobs func(ctx context.Context, in *indexpb.DropJobsRequest) (*commonpb.Status, error) CallGetJobStats func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) + CallCreateJobV2 func(ctx context.Context, req *indexpb.CreateJobV2Request) (*commonpb.Status, error) + CallQueryJobV2 func(ctx context.Context, req *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) + CallDropJobV2 func(ctx context.Context, req *indexpb.DropJobsV2Request) (*commonpb.Status, error) CallGetMetrics func(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) CallShowConfigurations func(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) @@ -114,6 +119,62 @@ func NewIndexNodeMock() *Mock { CallDropJobs: func(ctx context.Context, in *indexpb.DropJobsRequest) (*commonpb.Status, error) { return merr.Success(), nil }, + CallCreateJobV2: func(ctx context.Context, req *indexpb.CreateJobV2Request) (*commonpb.Status, error) { + return merr.Success(), nil + }, + CallQueryJobV2: func(ctx context.Context, req *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) { + switch req.GetJobType() { + case indexpb.JobType_JobTypeIndexJob: + results := make([]*indexpb.IndexTaskInfo, 0) + for _, buildID := range req.GetTaskIDs() { + results = append(results, &indexpb.IndexTaskInfo{ + BuildID: buildID, + State: commonpb.IndexState_Finished, + IndexFileKeys: []string{}, + SerializedSize: 1024, + FailReason: "", + CurrentIndexVersion: 1, + IndexStoreVersion: 1, + }) + } + return &indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: req.GetClusterID(), + Result: &indexpb.QueryJobsV2Response_IndexJobResults{ + IndexJobResults: &indexpb.IndexJobResults{ + Results: results, + }, + }, + }, nil + case indexpb.JobType_JobTypeAnalyzeJob: + results := make([]*indexpb.AnalyzeResult, 0) + for _, taskID := range req.GetTaskIDs() { + results = append(results, &indexpb.AnalyzeResult{ + TaskID: taskID, + State: 
indexpb.JobState_JobStateFinished, + CentroidsFile: fmt.Sprintf("%d/stats_file", taskID), + FailReason: "", + }) + } + return &indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: req.GetClusterID(), + Result: &indexpb.QueryJobsV2Response_AnalyzeJobResults{ + AnalyzeJobResults: &indexpb.AnalyzeResults{ + Results: results, + }, + }, + }, nil + default: + return &indexpb.QueryJobsV2Response{ + Status: merr.Status(errors.New("unknown job type")), + ClusterID: req.GetClusterID(), + }, nil + } + }, + CallDropJobV2: func(ctx context.Context, req *indexpb.DropJobsV2Request) (*commonpb.Status, error) { + return merr.Success(), nil + }, CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { return &indexpb.GetJobStatsResponse{ Status: merr.Success(), @@ -201,6 +262,18 @@ func (m *Mock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) return m.CallGetMetrics(ctx, req) } +func (m *Mock) CreateJobV2(ctx context.Context, req *indexpb.CreateJobV2Request) (*commonpb.Status, error) { + return m.CallCreateJobV2(ctx, req) +} + +func (m *Mock) QueryJobsV2(ctx context.Context, req *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) { + return m.CallQueryJobV2(ctx, req) +} + +func (m *Mock) DropJobsV2(ctx context.Context, req *indexpb.DropJobsV2Request) (*commonpb.Status, error) { + return m.CallDropJobV2(ctx, req) +} + // ShowConfigurations returns the configurations of Mock indexNode matching req.Pattern func (m *Mock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { return m.CallShowConfigurations(ctx, req) diff --git a/internal/indexnode/indexnode_service.go b/internal/indexnode/indexnode_service.go index ff54fb02b9..e1eee6280c 100644 --- a/internal/indexnode/indexnode_service.go +++ b/internal/indexnode/indexnode_service.go @@ -35,6 +35,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -77,7 +78,7 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.TotalLabel).Inc() taskCtx, taskCancel := context.WithCancel(i.loopCtx) - if oldInfo := i.loadOrStoreTask(req.GetClusterID(), req.GetBuildID(), &taskInfo{ + if oldInfo := i.loadOrStoreIndexTask(req.GetClusterID(), req.GetBuildID(), &indexTaskInfo{ cancel: taskCancel, state: commonpb.IndexState_InProgress, }); oldInfo != nil { @@ -92,7 +93,7 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest zap.String("accessKey", req.GetStorageConfig().GetAccessKeyID()), zap.Error(err), ) - i.deleteTaskInfos(ctx, []taskKey{{ClusterID: req.GetClusterID(), BuildID: req.GetBuildID()}}) + i.deleteIndexTaskInfos(ctx, []taskKey{{ClusterID: req.GetClusterID(), BuildID: req.GetBuildID()}}) metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.FailLabel).Inc() return merr.Status(err), nil } @@ -103,7 +104,7 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest task = newIndexBuildTask(taskCtx, taskCancel, req, cm, i) } ret := merr.Success() - if err := i.sched.IndexBuildQueue.Enqueue(task); err != nil { + if err := 
i.sched.TaskQueue.Enqueue(task); err != nil { log.Warn("IndexNode failed to schedule", zap.Error(err)) ret = merr.Status(err) @@ -127,10 +128,10 @@ func (i *IndexNode) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest }, nil } defer i.lifetime.Done() - infos := make(map[UniqueID]*taskInfo) - i.foreachTaskInfo(func(ClusterID string, buildID UniqueID, info *taskInfo) { + infos := make(map[UniqueID]*indexTaskInfo) + i.foreachIndexTaskInfo(func(ClusterID string, buildID UniqueID, info *indexTaskInfo) { if ClusterID == req.GetClusterID() { - infos[buildID] = &taskInfo{ + infos[buildID] = &indexTaskInfo{ state: info.state, fileKeys: common.CloneStringList(info.fileKeys), serializedSize: info.serializedSize, @@ -183,7 +184,7 @@ func (i *IndexNode) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest) for _, buildID := range req.GetBuildIDs() { keys = append(keys, taskKey{ClusterID: req.GetClusterID(), BuildID: buildID}) } - infos := i.deleteTaskInfos(ctx, keys) + infos := i.deleteIndexTaskInfos(ctx, keys) for _, info := range infos { if info.cancel != nil { info.cancel() @@ -203,7 +204,8 @@ func (i *IndexNode) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsReq }, nil } defer i.lifetime.Done() - unissued, active := i.sched.IndexBuildQueue.GetTaskNum() + unissued, active := i.sched.TaskQueue.GetTaskNum() + slots := 0 if i.sched.buildParallel > unissued+active { slots = i.sched.buildParallel - unissued - active @@ -271,3 +273,250 @@ func (i *IndexNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequ Status: merr.Status(merr.WrapErrMetricNotFound(metricType)), }, nil } + +func (i *IndexNode) CreateJobV2(ctx context.Context, req *indexpb.CreateJobV2Request) (*commonpb.Status, error) { + log := log.Ctx(ctx).With( + zap.String("clusterID", req.GetClusterID()), zap.Int64("taskID", req.GetTaskID()), + zap.String("jobType", req.GetJobType().String()), + ) + + if err := i.lifetime.Add(merr.IsHealthy); err != nil { + log.Warn("index node not ready", + zap.Error(err), + ) + return merr.Status(err), nil + } + defer i.lifetime.Done() + + log.Info("IndexNode receive CreateJob request...") + + switch req.GetJobType() { + case indexpb.JobType_JobTypeIndexJob: + indexRequest := req.GetIndexRequest() + log.Info("IndexNode building index ...", + zap.Int64("indexID", indexRequest.GetIndexID()), + zap.String("indexName", indexRequest.GetIndexName()), + zap.String("indexFilePrefix", indexRequest.GetIndexFilePrefix()), + zap.Int64("indexVersion", indexRequest.GetIndexVersion()), + zap.Strings("dataPaths", indexRequest.GetDataPaths()), + zap.Any("typeParams", indexRequest.GetTypeParams()), + zap.Any("indexParams", indexRequest.GetIndexParams()), + zap.Int64("numRows", indexRequest.GetNumRows()), + zap.Int32("current_index_version", indexRequest.GetCurrentIndexVersion()), + zap.String("storePath", indexRequest.GetStorePath()), + zap.Int64("storeVersion", indexRequest.GetStoreVersion()), + zap.String("indexStorePath", indexRequest.GetIndexStorePath()), + zap.Int64("dim", indexRequest.GetDim())) + taskCtx, taskCancel := context.WithCancel(i.loopCtx) + if oldInfo := i.loadOrStoreIndexTask(indexRequest.GetClusterID(), indexRequest.GetBuildID(), &indexTaskInfo{ + cancel: taskCancel, + state: commonpb.IndexState_InProgress, + }); oldInfo != nil { + err := merr.WrapErrIndexDuplicate(indexRequest.GetIndexName(), "building index task existed") + log.Warn("duplicated index build task", zap.Error(err)) + metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), 
metrics.FailLabel).Inc() + return merr.Status(err), nil + } + cm, err := i.storageFactory.NewChunkManager(i.loopCtx, indexRequest.GetStorageConfig()) + if err != nil { + log.Error("create chunk manager failed", zap.String("bucket", indexRequest.GetStorageConfig().GetBucketName()), + zap.String("accessKey", indexRequest.GetStorageConfig().GetAccessKeyID()), + zap.Error(err), + ) + i.deleteIndexTaskInfos(ctx, []taskKey{{ClusterID: indexRequest.GetClusterID(), BuildID: indexRequest.GetBuildID()}}) + metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.FailLabel).Inc() + return merr.Status(err), nil + } + var task task + if Params.CommonCfg.EnableStorageV2.GetAsBool() { + task = newIndexBuildTaskV2(taskCtx, taskCancel, indexRequest, i) + } else { + task = newIndexBuildTask(taskCtx, taskCancel, indexRequest, cm, i) + } + ret := merr.Success() + if err := i.sched.TaskQueue.Enqueue(task); err != nil { + log.Warn("IndexNode failed to schedule", + zap.Error(err)) + ret = merr.Status(err) + metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.FailLabel).Inc() + return ret, nil + } + metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SuccessLabel).Inc() + log.Info("IndexNode index job enqueued successfully", + zap.String("indexName", indexRequest.GetIndexName())) + return ret, nil + case indexpb.JobType_JobTypeAnalyzeJob: + analyzeRequest := req.GetAnalyzeRequest() + log.Info("receive analyze job", zap.Int64("collectionID", analyzeRequest.GetCollectionID()), + zap.Int64("partitionID", analyzeRequest.GetPartitionID()), + zap.Int64("fieldID", analyzeRequest.GetFieldID()), + zap.String("fieldName", analyzeRequest.GetFieldName()), + zap.String("dataType", analyzeRequest.GetFieldType().String()), + zap.Int64("version", analyzeRequest.GetVersion()), + zap.Int64("dim", analyzeRequest.GetDim()), + zap.Float64("trainSizeRatio", analyzeRequest.GetMaxTrainSizeRatio()), + zap.Int64("numClusters", analyzeRequest.GetNumClusters()), + ) + taskCtx, taskCancel := context.WithCancel(i.loopCtx) + if oldInfo := i.loadOrStoreAnalyzeTask(analyzeRequest.GetClusterID(), analyzeRequest.GetTaskID(), &analyzeTaskInfo{ + cancel: taskCancel, + state: indexpb.JobState_JobStateInProgress, + }); oldInfo != nil { + err := merr.WrapErrIndexDuplicate("", "analyze task already existed") + log.Warn("duplicated analyze task", zap.Error(err)) + return merr.Status(err), nil + } + t := &analyzeTask{ + ident: fmt.Sprintf("%s/%d", analyzeRequest.GetClusterID(), analyzeRequest.GetTaskID()), + ctx: taskCtx, + cancel: taskCancel, + req: analyzeRequest, + node: i, + tr: timerecord.NewTimeRecorder(fmt.Sprintf("ClusterID: %s, IndexBuildID: %d", req.GetClusterID(), req.GetTaskID())), + } + ret := merr.Success() + if err := i.sched.TaskQueue.Enqueue(t); err != nil { + log.Warn("IndexNode failed to schedule", zap.Error(err)) + ret = merr.Status(err) + return ret, nil + } + log.Info("IndexNode analyze job enqueued successfully") + return ret, nil + default: + log.Warn("IndexNode receive unknown type job") + return merr.Status(fmt.Errorf("IndexNode receive unknown type job with taskID: %d", req.GetTaskID())), nil + } +} + +func (i *IndexNode) QueryJobsV2(ctx context.Context, req *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) { + log := log.Ctx(ctx).With( + zap.String("clusterID", req.GetClusterID()), zap.Int64s("taskIDs", req.GetTaskIDs()), + ).WithRateGroup("QueryResult", 1, 60) + + 
if err := i.lifetime.Add(merr.IsHealthyOrStopping); err != nil { + log.Warn("IndexNode not ready", zap.Error(err)) + return &indexpb.QueryJobsV2Response{ + Status: merr.Status(err), + }, nil + } + defer i.lifetime.Done() + + switch req.GetJobType() { + case indexpb.JobType_JobTypeIndexJob: + infos := make(map[UniqueID]*indexTaskInfo) + i.foreachIndexTaskInfo(func(ClusterID string, buildID UniqueID, info *indexTaskInfo) { + if ClusterID == req.GetClusterID() { + infos[buildID] = &indexTaskInfo{ + state: info.state, + fileKeys: common.CloneStringList(info.fileKeys), + serializedSize: info.serializedSize, + failReason: info.failReason, + currentIndexVersion: info.currentIndexVersion, + indexStoreVersion: info.indexStoreVersion, + } + } + }) + results := make([]*indexpb.IndexTaskInfo, 0, len(req.GetTaskIDs())) + for i, buildID := range req.GetTaskIDs() { + results = append(results, &indexpb.IndexTaskInfo{ + BuildID: buildID, + State: commonpb.IndexState_IndexStateNone, + IndexFileKeys: nil, + SerializedSize: 0, + }) + if info, ok := infos[buildID]; ok { + results[i].State = info.state + results[i].IndexFileKeys = info.fileKeys + results[i].SerializedSize = info.serializedSize + results[i].FailReason = info.failReason + results[i].CurrentIndexVersion = info.currentIndexVersion + results[i].IndexStoreVersion = info.indexStoreVersion + } + } + log.Debug("query index jobs result success", zap.Any("results", results)) + return &indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: req.GetClusterID(), + Result: &indexpb.QueryJobsV2Response_IndexJobResults{ + IndexJobResults: &indexpb.IndexJobResults{ + Results: results, + }, + }, + }, nil + case indexpb.JobType_JobTypeAnalyzeJob: + results := make([]*indexpb.AnalyzeResult, 0, len(req.GetTaskIDs())) + for _, taskID := range req.GetTaskIDs() { + info := i.getAnalyzeTaskInfo(req.GetClusterID(), taskID) + if info != nil { + results = append(results, &indexpb.AnalyzeResult{ + TaskID: taskID, + State: info.state, + FailReason: info.failReason, + CentroidsFile: info.centroidsFile, + }) + } + } + log.Debug("query analyze jobs result success", zap.Any("results", results)) + return &indexpb.QueryJobsV2Response{ + Status: merr.Success(), + ClusterID: req.GetClusterID(), + Result: &indexpb.QueryJobsV2Response_AnalyzeJobResults{ + AnalyzeJobResults: &indexpb.AnalyzeResults{ + Results: results, + }, + }, + }, nil + default: + log.Warn("IndexNode receive querying unknown type jobs") + return &indexpb.QueryJobsV2Response{ + Status: merr.Status(fmt.Errorf("IndexNode receive querying unknown type jobs")), + }, nil + } +} + +func (i *IndexNode) DropJobsV2(ctx context.Context, req *indexpb.DropJobsV2Request) (*commonpb.Status, error) { + log := log.Ctx(ctx).With(zap.String("clusterID", req.GetClusterID()), + zap.Int64s("taskIDs", req.GetTaskIDs()), + zap.String("jobType", req.GetJobType().String()), + ) + + if err := i.lifetime.Add(merr.IsHealthyOrStopping); err != nil { + log.Warn("IndexNode not ready", zap.Error(err)) + return merr.Status(err), nil + } + defer i.lifetime.Done() + + log.Info("IndexNode receive DropJobs request") + + switch req.GetJobType() { + case indexpb.JobType_JobTypeIndexJob: + keys := make([]taskKey, 0, len(req.GetTaskIDs())) + for _, buildID := range req.GetTaskIDs() { + keys = append(keys, taskKey{ClusterID: req.GetClusterID(), BuildID: buildID}) + } + infos := i.deleteIndexTaskInfos(ctx, keys) + for _, info := range infos { + if info.cancel != nil { + info.cancel() + } + } + log.Info("drop index build jobs success") + return 
merr.Success(), nil + case indexpb.JobType_JobTypeAnalyzeJob: + keys := make([]taskKey, 0, len(req.GetTaskIDs())) + for _, taskID := range req.GetTaskIDs() { + keys = append(keys, taskKey{ClusterID: req.GetClusterID(), BuildID: taskID}) + } + infos := i.deleteAnalyzeTaskInfos(ctx, keys) + for _, info := range infos { + if info.cancel != nil { + info.cancel() + } + } + log.Info("drop analyze jobs success") + return merr.Success(), nil + default: + log.Warn("IndexNode receive dropping unknown type jobs") + return merr.Status(fmt.Errorf("IndexNode receive dropping unknown type jobs")), nil + } +} diff --git a/internal/indexnode/indexnode_service_test.go b/internal/indexnode/indexnode_service_test.go index 255551d3e2..a41cbb4d4f 100644 --- a/internal/indexnode/indexnode_service_test.go +++ b/internal/indexnode/indexnode_service_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -100,3 +101,132 @@ func TestMockFieldData(t *testing.T) { chunkMgr.mockFieldData(100000, 8, 0, 0, 1) } + +type IndexNodeServiceSuite struct { + suite.Suite + cluster string + collectionID int64 + partitionID int64 + taskID int64 + fieldID int64 + segmentID int64 +} + +func (suite *IndexNodeServiceSuite) SetupTest() { + suite.cluster = "test_cluster" + suite.collectionID = 100 + suite.partitionID = 102 + suite.taskID = 11111 + suite.fieldID = 103 + suite.segmentID = 104 +} + +func (suite *IndexNodeServiceSuite) Test_AbnormalIndexNode() { + in, err := NewMockIndexNodeComponent(context.TODO()) + suite.NoError(err) + suite.Nil(in.Stop()) + + ctx := context.TODO() + status, err := in.CreateJob(ctx, &indexpb.CreateJobRequest{}) + suite.NoError(err) + suite.ErrorIs(merr.Error(status), merr.ErrServiceNotReady) + + qresp, err := in.QueryJobs(ctx, &indexpb.QueryJobsRequest{}) + suite.NoError(err) + suite.ErrorIs(merr.Error(qresp.GetStatus()), merr.ErrServiceNotReady) + + status, err = in.DropJobs(ctx, &indexpb.DropJobsRequest{}) + suite.NoError(err) + suite.ErrorIs(merr.Error(status), merr.ErrServiceNotReady) + + jobNumRsp, err := in.GetJobStats(ctx, &indexpb.GetJobStatsRequest{}) + suite.NoError(err) + suite.ErrorIs(merr.Error(jobNumRsp.GetStatus()), merr.ErrServiceNotReady) + + metricsResp, err := in.GetMetrics(ctx, &milvuspb.GetMetricsRequest{}) + err = merr.CheckRPCCall(metricsResp, err) + suite.ErrorIs(err, merr.ErrServiceNotReady) + + configurationResp, err := in.ShowConfigurations(ctx, &internalpb.ShowConfigurationsRequest{}) + err = merr.CheckRPCCall(configurationResp, err) + suite.ErrorIs(err, merr.ErrServiceNotReady) + + status, err = in.CreateJobV2(ctx, &indexpb.CreateJobV2Request{}) + err = merr.CheckRPCCall(status, err) + suite.ErrorIs(err, merr.ErrServiceNotReady) + + queryAnalyzeResultResp, err := in.QueryJobsV2(ctx, &indexpb.QueryJobsV2Request{}) + err = merr.CheckRPCCall(queryAnalyzeResultResp, err) + suite.ErrorIs(err, merr.ErrServiceNotReady) + + dropAnalyzeTasksResp, err := in.DropJobsV2(ctx, &indexpb.DropJobsV2Request{}) + err = merr.CheckRPCCall(dropAnalyzeTasksResp, err) + suite.ErrorIs(err, merr.ErrServiceNotReady) +} + +func (suite *IndexNodeServiceSuite) Test_Method() { + ctx := context.TODO() + in, err := NewMockIndexNodeComponent(context.TODO()) + suite.NoError(err) + suite.NoError(in.Stop()) + + in.UpdateStateCode(commonpb.StateCode_Healthy) + + suite.Run("CreateJobV2", func() { + req := &indexpb.AnalyzeRequest{ + 
ClusterID: suite.cluster, + TaskID: suite.taskID, + CollectionID: suite.collectionID, + PartitionID: suite.partitionID, + FieldID: suite.fieldID, + SegmentStats: map[int64]*indexpb.SegmentStats{ + suite.segmentID: { + ID: suite.segmentID, + NumRows: 1024, + LogIDs: []int64{1, 2, 3}, + }, + }, + Version: 1, + StorageConfig: nil, + } + + resp, err := in.CreateJobV2(ctx, &indexpb.CreateJobV2Request{ + ClusterID: suite.cluster, + TaskID: suite.taskID, + JobType: indexpb.JobType_JobTypeAnalyzeJob, + Request: &indexpb.CreateJobV2Request_AnalyzeRequest{ + AnalyzeRequest: req, + }, + }) + err = merr.CheckRPCCall(resp, err) + suite.NoError(err) + }) + + suite.Run("QueryJobsV2", func() { + req := &indexpb.QueryJobsV2Request{ + ClusterID: suite.cluster, + TaskIDs: []int64{suite.taskID}, + JobType: indexpb.JobType_JobTypeIndexJob, + } + + resp, err := in.QueryJobsV2(ctx, req) + err = merr.CheckRPCCall(resp, err) + suite.NoError(err) + }) + + suite.Run("DropJobsV2", func() { + req := &indexpb.DropJobsV2Request{ + ClusterID: suite.cluster, + TaskIDs: []int64{suite.taskID}, + JobType: indexpb.JobType_JobTypeIndexJob, + } + + resp, err := in.DropJobsV2(ctx, req) + err = merr.CheckRPCCall(resp, err) + suite.NoError(err) + }) +} + +func Test_IndexNodeServiceSuite(t *testing.T) { + suite.Run(t, new(IndexNodeServiceSuite)) +} diff --git a/internal/indexnode/indexnode_test.go b/internal/indexnode/indexnode_test.go index 20aadd172b..156d99d716 100644 --- a/internal/indexnode/indexnode_test.go +++ b/internal/indexnode/indexnode_test.go @@ -110,17 +110,17 @@ func TestIndexTaskWhenStoppingNode(t *testing.T) { paramtable.Init() in := NewIndexNode(ctx, factory) - in.loadOrStoreTask("cluster-1", 1, &taskInfo{ + in.loadOrStoreIndexTask("cluster-1", 1, &indexTaskInfo{ state: commonpb.IndexState_InProgress, }) - in.loadOrStoreTask("cluster-2", 2, &taskInfo{ + in.loadOrStoreIndexTask("cluster-2", 2, &indexTaskInfo{ state: commonpb.IndexState_Finished, }) assert.True(t, in.hasInProgressTask()) go func() { time.Sleep(2 * time.Second) - in.storeTaskState("cluster-1", 1, commonpb.IndexState_Finished, "") + in.storeIndexTaskState("cluster-1", 1, commonpb.IndexState_Finished, "") }() noTaskChan := make(chan struct{}) go func() { diff --git a/internal/indexnode/task.go b/internal/indexnode/task.go index 177f1282c5..003d2621c1 100644 --- a/internal/indexnode/task.go +++ b/internal/indexnode/task.go @@ -19,28 +19,9 @@ package indexnode import ( "context" "fmt" - "strconv" - "strings" - "time" - "github.com/cockroachdb/errors" - "go.uber.org/zap" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/indexcgopb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/util/indexcgowrapper" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" - "github.com/milvus-io/milvus/pkg/util/indexparams" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/metautil" - "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/timerecord" ) var ( @@ -50,550 +31,14 @@ var ( type Blob = storage.Blob -type taskInfo struct { - cancel context.CancelFunc - state commonpb.IndexState - fileKeys []string - serializedSize uint64 - failReason string - currentIndexVersion 
int32 - indexStoreVersion int64 -} - type task interface { Ctx() context.Context Name() string - Prepare(context.Context) error - BuildIndex(context.Context) error - SaveIndexFiles(context.Context) error OnEnqueue(context.Context) error - SetState(state commonpb.IndexState, failReason string) - GetState() commonpb.IndexState + SetState(state indexpb.JobState, failReason string) + GetState() indexpb.JobState + PreExecute(context.Context) error + Execute(context.Context) error + PostExecute(context.Context) error Reset() } - -type indexBuildTaskV2 struct { - *indexBuildTask -} - -func newIndexBuildTaskV2(ctx context.Context, - cancel context.CancelFunc, - req *indexpb.CreateJobRequest, - node *IndexNode, -) *indexBuildTaskV2 { - t := &indexBuildTaskV2{ - indexBuildTask: &indexBuildTask{ - ident: fmt.Sprintf("%s/%d", req.GetClusterID(), req.GetBuildID()), - cancel: cancel, - ctx: ctx, - req: req, - tr: timerecord.NewTimeRecorder(fmt.Sprintf("IndexBuildID: %d, ClusterID: %s", req.GetBuildID(), req.GetClusterID())), - node: node, - }, - } - - t.parseParams() - return t -} - -func (it *indexBuildTaskV2) parseParams() { - // fill field for requests before v2.5.0 - if it.req.GetField() == nil || it.req.GetField().GetDataType() == schemapb.DataType_None { - it.req.Field = &schemapb.FieldSchema{ - FieldID: it.req.GetFieldID(), - Name: it.req.GetFieldName(), - DataType: it.req.GetFieldType(), - } - } -} - -func (it *indexBuildTaskV2) BuildIndex(ctx context.Context) error { - log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()), - zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()), - zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) - - indexType := it.newIndexParams[common.IndexTypeKey] - if indexType == indexparamcheck.IndexDISKANN { - // check index node support disk index - if !Params.IndexNodeCfg.EnableDisk.GetAsBool() { - log.Warn("IndexNode don't support build disk index", - zap.String("index type", it.newIndexParams[common.IndexTypeKey]), - zap.Bool("enable disk", Params.IndexNodeCfg.EnableDisk.GetAsBool())) - return merr.WrapErrIndexNotSupported("disk index") - } - - // check load size and size of field data - localUsedSize, err := indexcgowrapper.GetLocalUsedSize(paramtable.Get().LocalStorageCfg.Path.GetValue()) - if err != nil { - log.Warn("IndexNode get local used size failed") - return err - } - fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) - if err != nil { - log.Warn("IndexNode get local used size failed") - return err - } - usedLocalSizeWhenBuild := int64(float64(fieldDataSize)*diskUsageRatio) + localUsedSize - maxUsedLocalSize := int64(Params.IndexNodeCfg.DiskCapacityLimit.GetAsFloat() * Params.IndexNodeCfg.MaxDiskUsagePercentage.GetAsFloat()) - - if usedLocalSizeWhenBuild > maxUsedLocalSize { - log.Warn("IndexNode don't has enough disk size to build disk ann index", - zap.Int64("usedLocalSizeWhenBuild", usedLocalSizeWhenBuild), - zap.Int64("maxUsedLocalSize", maxUsedLocalSize)) - return merr.WrapErrServiceDiskLimitExceeded(float32(usedLocalSizeWhenBuild), float32(maxUsedLocalSize)) - } - - err = indexparams.SetDiskIndexBuildParams(it.newIndexParams, int64(fieldDataSize)) - if err != nil { - log.Warn("failed to fill disk index params", zap.Error(err)) - return err - } - } - - storageConfig := &indexcgopb.StorageConfig{ - Address: it.req.GetStorageConfig().GetAddress(), - AccessKeyID: 
it.req.GetStorageConfig().GetAccessKeyID(), - SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(), - UseSSL: it.req.GetStorageConfig().GetUseSSL(), - BucketName: it.req.GetStorageConfig().GetBucketName(), - RootPath: it.req.GetStorageConfig().GetRootPath(), - UseIAM: it.req.GetStorageConfig().GetUseIAM(), - IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(), - StorageType: it.req.GetStorageConfig().GetStorageType(), - UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(), - Region: it.req.GetStorageConfig().GetRegion(), - CloudProvider: it.req.GetStorageConfig().GetCloudProvider(), - RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(), - SslCACert: it.req.GetStorageConfig().GetSslCACert(), - } - - optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields())) - for _, optField := range it.req.GetOptionalScalarFields() { - optFields = append(optFields, &indexcgopb.OptionalFieldInfo{ - FieldID: optField.GetFieldID(), - FieldName: optField.GetFieldName(), - FieldType: optField.GetFieldType(), - DataPaths: optField.GetDataPaths(), - }) - } - - buildIndexParams := &indexcgopb.BuildIndexInfo{ - ClusterID: it.req.GetClusterID(), - BuildID: it.req.GetBuildID(), - CollectionID: it.req.GetCollectionID(), - PartitionID: it.req.GetPartitionID(), - SegmentID: it.req.GetSegmentID(), - IndexVersion: it.req.GetIndexVersion(), - CurrentIndexVersion: it.req.GetCurrentIndexVersion(), - NumRows: it.req.GetNumRows(), - Dim: it.req.GetDim(), - IndexFilePrefix: it.req.GetIndexFilePrefix(), - InsertFiles: it.req.GetDataPaths(), - FieldSchema: it.req.GetField(), - StorageConfig: storageConfig, - IndexParams: mapToKVPairs(it.newIndexParams), - TypeParams: mapToKVPairs(it.newTypeParams), - StorePath: it.req.GetStorePath(), - StoreVersion: it.req.GetStoreVersion(), - IndexStorePath: it.req.GetIndexStorePath(), - OptFields: optFields, - } - - var err error - it.index, err = indexcgowrapper.CreateIndexV2(ctx, buildIndexParams) - if err != nil { - if it.index != nil && it.index.CleanLocalData() != nil { - log.Warn("failed to clean cached data on disk after build index failed") - } - log.Warn("failed to build index", zap.Error(err)) - return err - } - - buildIndexLatency := it.tr.RecordSpan() - metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(buildIndexLatency.Milliseconds())) - - log.Info("Successfully build index") - return nil -} - -func (it *indexBuildTaskV2) SaveIndexFiles(ctx context.Context) error { - log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()), - zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()), - zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) - - gcIndex := func() { - if err := it.index.Delete(); err != nil { - log.Warn("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err)) - } - } - version, err := it.index.UpLoadV2() - if err != nil { - log.Warn("failed to upload index", zap.Error(err)) - gcIndex() - return err - } - - encodeIndexFileDur := it.tr.Record("index serialize and upload done") - metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(encodeIndexFileDur.Seconds()) - - // early release index for gc, and we can ensure that Delete is idempotent. 
- gcIndex() - - // use serialized size before encoding - var serializedSize uint64 - saveFileKeys := make([]string, 0) - - it.node.storeIndexFilesAndStatisticV2(it.req.GetClusterID(), it.req.GetBuildID(), saveFileKeys, serializedSize, it.req.GetCurrentIndexVersion(), version) - log.Debug("save index files done", zap.Strings("IndexFiles", saveFileKeys)) - saveIndexFileDur := it.tr.RecordSpan() - metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(saveIndexFileDur.Seconds()) - it.tr.Elapse("index building all done") - log.Info("Successfully save index files") - return nil -} - -// IndexBuildTask is used to record the information of the index tasks. -type indexBuildTask struct { - ident string - cancel context.CancelFunc - ctx context.Context - - cm storage.ChunkManager - index indexcgowrapper.CodecIndex - req *indexpb.CreateJobRequest - newTypeParams map[string]string - newIndexParams map[string]string - tr *timerecord.TimeRecorder - queueDur time.Duration - node *IndexNode -} - -func newIndexBuildTask(ctx context.Context, - cancel context.CancelFunc, - req *indexpb.CreateJobRequest, - cm storage.ChunkManager, - node *IndexNode, -) *indexBuildTask { - t := &indexBuildTask{ - ident: fmt.Sprintf("%s/%d", req.GetClusterID(), req.GetBuildID()), - cancel: cancel, - ctx: ctx, - cm: cm, - req: req, - tr: timerecord.NewTimeRecorder(fmt.Sprintf("IndexBuildID: %d, ClusterID: %s", req.GetBuildID(), req.GetClusterID())), - node: node, - } - - t.parseParams() - return t -} - -func (it *indexBuildTask) parseParams() { - // fill field for requests before v2.5.0 - if it.req.GetField() == nil || it.req.GetField().GetDataType() == schemapb.DataType_None { - it.req.Field = &schemapb.FieldSchema{ - FieldID: it.req.GetFieldID(), - Name: it.req.GetFieldName(), - DataType: it.req.GetFieldType(), - } - } -} - -func (it *indexBuildTask) Reset() { - it.ident = "" - it.cancel = nil - it.ctx = nil - it.cm = nil - it.index = nil - it.req = nil - it.newTypeParams = nil - it.newIndexParams = nil - it.tr = nil - it.node = nil -} - -// Ctx is the context of index tasks. -func (it *indexBuildTask) Ctx() context.Context { - return it.ctx -} - -// Name is the name of task to build index. -func (it *indexBuildTask) Name() string { - return it.ident -} - -func (it *indexBuildTask) SetState(state commonpb.IndexState, failReason string) { - it.node.storeTaskState(it.req.GetClusterID(), it.req.GetBuildID(), state, failReason) -} - -func (it *indexBuildTask) GetState() commonpb.IndexState { - return it.node.loadTaskState(it.req.GetClusterID(), it.req.GetBuildID()) -} - -// OnEnqueue enqueues indexing tasks. 
-func (it *indexBuildTask) OnEnqueue(ctx context.Context) error { - it.queueDur = 0 - it.tr.RecordSpan() - log.Ctx(ctx).Info("IndexNode IndexBuilderTask Enqueue", zap.Int64("buildID", it.req.GetBuildID()), - zap.Int64("segmentID", it.req.GetSegmentID())) - return nil -} - -func (it *indexBuildTask) Prepare(ctx context.Context) error { - it.queueDur = it.tr.RecordSpan() - log.Ctx(ctx).Info("Begin to prepare indexBuildTask", zap.Int64("buildID", it.req.GetBuildID()), - zap.Int64("Collection", it.req.GetCollectionID()), zap.Int64("SegmentID", it.req.GetSegmentID())) - - typeParams := make(map[string]string) - indexParams := make(map[string]string) - - if len(it.req.DataPaths) == 0 { - for _, id := range it.req.GetDataIds() { - path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), it.req.GetField().GetFieldID(), id) - it.req.DataPaths = append(it.req.DataPaths, path) - } - } - - if it.req.OptionalScalarFields != nil { - for _, optFields := range it.req.GetOptionalScalarFields() { - if len(optFields.DataPaths) == 0 { - for _, id := range optFields.DataIds { - path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), optFields.FieldID, id) - optFields.DataPaths = append(optFields.DataPaths, path) - } - } - } - } - - // type params can be removed - for _, kvPair := range it.req.GetTypeParams() { - key, value := kvPair.GetKey(), kvPair.GetValue() - typeParams[key] = value - indexParams[key] = value - } - - for _, kvPair := range it.req.GetIndexParams() { - key, value := kvPair.GetKey(), kvPair.GetValue() - // knowhere would report error if encountered the unknown key, - // so skip this - if key == common.MmapEnabledKey { - continue - } - indexParams[key] = value - } - it.newTypeParams = typeParams - it.newIndexParams = indexParams - - if it.req.GetDim() == 0 { - // fill dim for requests before v2.4.0 - if dimStr, ok := typeParams[common.DimKey]; ok { - var err error - it.req.Dim, err = strconv.ParseInt(dimStr, 10, 64) - if err != nil { - log.Ctx(ctx).Error("parse dimesion failed", zap.Error(err)) - // ignore error - } - } - } - - if it.req.GetCollectionID() == 0 { - err := it.parseFieldMetaFromBinlog(ctx) - if err != nil { - log.Ctx(ctx).Warn("parse field meta from binlog failed", zap.Error(err)) - return err - } - } - - log.Ctx(ctx).Info("Successfully prepare indexBuildTask", zap.Int64("buildID", it.req.GetBuildID()), - zap.Int64("collectionID", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID())) - return nil -} - -func (it *indexBuildTask) BuildIndex(ctx context.Context) error { - log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()), - zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()), - zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) - - indexType := it.newIndexParams[common.IndexTypeKey] - if indexType == indexparamcheck.IndexDISKANN { - // check index node support disk index - if !Params.IndexNodeCfg.EnableDisk.GetAsBool() { - log.Warn("IndexNode don't support build disk index", - zap.String("index type", it.newIndexParams[common.IndexTypeKey]), - zap.Bool("enable disk", Params.IndexNodeCfg.EnableDisk.GetAsBool())) - return errors.New("index node don't support build disk index") - } - - // check load size and size of field data - localUsedSize, err := 
indexcgowrapper.GetLocalUsedSize(paramtable.Get().LocalStorageCfg.Path.GetValue()) - if err != nil { - log.Warn("IndexNode get local used size failed") - return err - } - fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) - if err != nil { - log.Warn("IndexNode get local used size failed") - return err - } - usedLocalSizeWhenBuild := int64(float64(fieldDataSize)*diskUsageRatio) + localUsedSize - maxUsedLocalSize := int64(Params.IndexNodeCfg.DiskCapacityLimit.GetAsFloat() * Params.IndexNodeCfg.MaxDiskUsagePercentage.GetAsFloat()) - - if usedLocalSizeWhenBuild > maxUsedLocalSize { - log.Warn("IndexNode don't has enough disk size to build disk ann index", - zap.Int64("usedLocalSizeWhenBuild", usedLocalSizeWhenBuild), - zap.Int64("maxUsedLocalSize", maxUsedLocalSize)) - return errors.New("index node don't has enough disk size to build disk ann index") - } - - err = indexparams.SetDiskIndexBuildParams(it.newIndexParams, int64(fieldDataSize)) - if err != nil { - log.Warn("failed to fill disk index params", zap.Error(err)) - return err - } - } - - storageConfig := &indexcgopb.StorageConfig{ - Address: it.req.GetStorageConfig().GetAddress(), - AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), - SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(), - UseSSL: it.req.GetStorageConfig().GetUseSSL(), - BucketName: it.req.GetStorageConfig().GetBucketName(), - RootPath: it.req.GetStorageConfig().GetRootPath(), - UseIAM: it.req.GetStorageConfig().GetUseIAM(), - IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(), - StorageType: it.req.GetStorageConfig().GetStorageType(), - UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(), - Region: it.req.GetStorageConfig().GetRegion(), - CloudProvider: it.req.GetStorageConfig().GetCloudProvider(), - RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(), - SslCACert: it.req.GetStorageConfig().GetSslCACert(), - } - - optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields())) - for _, optField := range it.req.GetOptionalScalarFields() { - optFields = append(optFields, &indexcgopb.OptionalFieldInfo{ - FieldID: optField.GetFieldID(), - FieldName: optField.GetFieldName(), - FieldType: optField.GetFieldType(), - DataPaths: optField.GetDataPaths(), - }) - } - - buildIndexParams := &indexcgopb.BuildIndexInfo{ - ClusterID: it.req.GetClusterID(), - BuildID: it.req.GetBuildID(), - CollectionID: it.req.GetCollectionID(), - PartitionID: it.req.GetPartitionID(), - SegmentID: it.req.GetSegmentID(), - IndexVersion: it.req.GetIndexVersion(), - CurrentIndexVersion: it.req.GetCurrentIndexVersion(), - NumRows: it.req.GetNumRows(), - Dim: it.req.GetDim(), - IndexFilePrefix: it.req.GetIndexFilePrefix(), - InsertFiles: it.req.GetDataPaths(), - FieldSchema: it.req.GetField(), - StorageConfig: storageConfig, - IndexParams: mapToKVPairs(it.newIndexParams), - TypeParams: mapToKVPairs(it.newTypeParams), - StorePath: it.req.GetStorePath(), - StoreVersion: it.req.GetStoreVersion(), - IndexStorePath: it.req.GetIndexStorePath(), - OptFields: optFields, - } - - log.Info("debug create index", zap.Any("buildIndexParams", buildIndexParams)) - var err error - it.index, err = indexcgowrapper.CreateIndex(ctx, buildIndexParams) - if err != nil { - if it.index != nil && it.index.CleanLocalData() != nil { - log.Warn("failed to clean cached data on disk after build index failed") - } - log.Warn("failed to build index", zap.Error(err)) - return err - } - - 
buildIndexLatency := it.tr.RecordSpan() - metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(buildIndexLatency.Seconds()) - - log.Info("Successfully build index") - return nil -} - -func (it *indexBuildTask) SaveIndexFiles(ctx context.Context) error { - log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()), - zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()), - zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) - - gcIndex := func() { - if err := it.index.Delete(); err != nil { - log.Warn("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err)) - } - } - indexFilePath2Size, err := it.index.UpLoad() - if err != nil { - log.Warn("failed to upload index", zap.Error(err)) - gcIndex() - return err - } - encodeIndexFileDur := it.tr.Record("index serialize and upload done") - metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(encodeIndexFileDur.Seconds()) - - // early release index for gc, and we can ensure that Delete is idempotent. - gcIndex() - - // use serialized size before encoding - var serializedSize uint64 - saveFileKeys := make([]string, 0) - for filePath, fileSize := range indexFilePath2Size { - serializedSize += uint64(fileSize) - parts := strings.Split(filePath, "/") - fileKey := parts[len(parts)-1] - saveFileKeys = append(saveFileKeys, fileKey) - } - - it.node.storeIndexFilesAndStatistic(it.req.GetClusterID(), it.req.GetBuildID(), saveFileKeys, serializedSize, it.req.GetCurrentIndexVersion()) - log.Debug("save index files done", zap.Strings("IndexFiles", saveFileKeys)) - saveIndexFileDur := it.tr.RecordSpan() - metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(saveIndexFileDur.Seconds()) - it.tr.Elapse("index building all done") - log.Info("Successfully save index files") - return nil -} - -func (it *indexBuildTask) parseFieldMetaFromBinlog(ctx context.Context) error { - // fill collectionID, partitionID... for requests before v2.4.0 - toLoadDataPaths := it.req.GetDataPaths() - if len(toLoadDataPaths) == 0 { - return merr.WrapErrParameterInvalidMsg("data insert path must be not empty") - } - data, err := it.cm.Read(ctx, toLoadDataPaths[0]) - if err != nil { - if errors.Is(err, merr.ErrIoKeyNotFound) { - return err - } - return err - } - - var insertCodec storage.InsertCodec - collectionID, partitionID, segmentID, insertData, err := insertCodec.DeserializeAll([]*Blob{{Key: toLoadDataPaths[0], Value: data}}) - if err != nil { - return err - } - if len(insertData.Data) != 1 { - return merr.WrapErrParameterInvalidMsg("we expect only one field in deserialized insert data") - } - - it.req.CollectionID = collectionID - it.req.PartitionID = partitionID - it.req.SegmentID = segmentID - if it.req.GetField().GetFieldID() == 0 { - for fID, value := range insertData.Data { - it.req.Field.DataType = value.GetDataType() - it.req.Field.FieldID = fID - break - } - } - it.req.CurrentIndexVersion = getCurrentIndexVersion(it.req.GetCurrentIndexVersion()) - - return nil -} diff --git a/internal/indexnode/task_analyze.go b/internal/indexnode/task_analyze.go new file mode 100644 index 0000000000..3608ec3519 --- /dev/null +++ b/internal/indexnode/task_analyze.go @@ -0,0 +1,215 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package indexnode + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/clusteringpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/util/analyzecgowrapper" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/metautil" + "github.com/milvus-io/milvus/pkg/util/timerecord" +) + +type analyzeTask struct { + ident string + ctx context.Context + cancel context.CancelFunc + req *indexpb.AnalyzeRequest + + tr *timerecord.TimeRecorder + queueDur time.Duration + node *IndexNode + analyze analyzecgowrapper.CodecAnalyze + + startTime int64 + endTime int64 +} + +func (at *analyzeTask) Ctx() context.Context { + return at.ctx +} + +func (at *analyzeTask) Name() string { + return at.ident +} + +func (at *analyzeTask) PreExecute(ctx context.Context) error { + at.queueDur = at.tr.RecordSpan() + log := log.Ctx(ctx).With(zap.String("clusterID", at.req.GetClusterID()), + zap.Int64("taskID", at.req.GetTaskID()), zap.Int64("Collection", at.req.GetCollectionID()), + zap.Int64("partitionID", at.req.GetPartitionID()), zap.Int64("fieldID", at.req.GetFieldID())) + log.Info("Begin to prepare analyze task") + + log.Info("Successfully prepare analyze task, nothing to do...") + return nil +} + +func (at *analyzeTask) Execute(ctx context.Context) error { + var err error + + log := log.Ctx(ctx).With(zap.String("clusterID", at.req.GetClusterID()), + zap.Int64("taskID", at.req.GetTaskID()), zap.Int64("Collection", at.req.GetCollectionID()), + zap.Int64("partitionID", at.req.GetPartitionID()), zap.Int64("fieldID", at.req.GetFieldID())) + + log.Info("Begin to build analyze task") + if err != nil { + log.Warn("create analyze info failed", zap.Error(err)) + return err + } + + storageConfig := &clusteringpb.StorageConfig{ + Address: at.req.GetStorageConfig().GetAddress(), + AccessKeyID: at.req.GetStorageConfig().GetAccessKeyID(), + SecretAccessKey: at.req.GetStorageConfig().GetSecretAccessKey(), + UseSSL: at.req.GetStorageConfig().GetUseSSL(), + BucketName: at.req.GetStorageConfig().GetBucketName(), + RootPath: at.req.GetStorageConfig().GetRootPath(), + UseIAM: at.req.GetStorageConfig().GetUseIAM(), + IAMEndpoint: at.req.GetStorageConfig().GetIAMEndpoint(), + StorageType: at.req.GetStorageConfig().GetStorageType(), + UseVirtualHost: at.req.GetStorageConfig().GetUseVirtualHost(), + Region: at.req.GetStorageConfig().GetRegion(), + CloudProvider: at.req.GetStorageConfig().GetCloudProvider(), + RequestTimeoutMs: at.req.GetStorageConfig().GetRequestTimeoutMs(), + SslCACert: at.req.GetStorageConfig().GetSslCACert(), + } + + numRowsMap := make(map[int64]int64) + segmentInsertFilesMap := 
make(map[int64]*clusteringpb.InsertFiles) + + for segID, stats := range at.req.GetSegmentStats() { + numRows := stats.GetNumRows() + numRowsMap[segID] = numRows + log.Info("append segment rows", zap.Int64("segment id", segID), zap.Int64("rows", numRows)) + if err != nil { + log.Warn("append segment num rows failed", zap.Error(err)) + return err + } + insertFiles := make([]string, 0, len(stats.GetLogIDs())) + for _, id := range stats.GetLogIDs() { + path := metautil.BuildInsertLogPath(at.req.GetStorageConfig().RootPath, + at.req.GetCollectionID(), at.req.GetPartitionID(), segID, at.req.GetFieldID(), id) + insertFiles = append(insertFiles, path) + if err != nil { + log.Warn("append insert binlog path failed", zap.Error(err)) + return err + } + } + segmentInsertFilesMap[segID] = &clusteringpb.InsertFiles{InsertFiles: insertFiles} + } + + field := at.req.GetField() + if field == nil || field.GetDataType() == schemapb.DataType_None { + field = &schemapb.FieldSchema{ + FieldID: at.req.GetFieldID(), + Name: at.req.GetFieldName(), + DataType: at.req.GetFieldType(), + } + } + + analyzeInfo := &clusteringpb.AnalyzeInfo{ + ClusterID: at.req.GetClusterID(), + BuildID: at.req.GetTaskID(), + CollectionID: at.req.GetCollectionID(), + PartitionID: at.req.GetPartitionID(), + Version: at.req.GetVersion(), + Dim: at.req.GetDim(), + StorageConfig: storageConfig, + NumClusters: at.req.GetNumClusters(), + TrainSize: int64(float64(hardware.GetMemoryCount()) * at.req.GetMaxTrainSizeRatio()), + MinClusterRatio: at.req.GetMinClusterSizeRatio(), + MaxClusterRatio: at.req.GetMaxClusterSizeRatio(), + MaxClusterSize: at.req.GetMaxClusterSize(), + NumRows: numRowsMap, + InsertFiles: segmentInsertFilesMap, + FieldSchema: field, + } + + at.analyze, err = analyzecgowrapper.Analyze(ctx, analyzeInfo) + if err != nil { + log.Error("failed to analyze data", zap.Error(err)) + return err + } + + analyzeLatency := at.tr.RecordSpan() + log.Info("analyze done", zap.Int64("analyze cost", analyzeLatency.Milliseconds())) + return nil +} + +func (at *analyzeTask) PostExecute(ctx context.Context) error { + log := log.Ctx(ctx).With(zap.String("clusterID", at.req.GetClusterID()), + zap.Int64("taskID", at.req.GetTaskID()), zap.Int64("Collection", at.req.GetCollectionID()), + zap.Int64("partitionID", at.req.GetPartitionID()), zap.Int64("fieldID", at.req.GetFieldID())) + gc := func() { + if err := at.analyze.Delete(); err != nil { + log.Error("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err)) + } + } + defer gc() + + centroidsFile, _, _, _, err := at.analyze.GetResult(len(at.req.GetSegmentStats())) + if err != nil { + log.Error("failed to upload index", zap.Error(err)) + return err + } + log.Info("analyze result", zap.String("centroidsFile", centroidsFile)) + + at.endTime = time.Now().UnixMicro() + at.node.storeAnalyzeFilesAndStatistic(at.req.GetClusterID(), + at.req.GetTaskID(), + centroidsFile) + at.tr.Elapse("index building all done") + log.Info("Successfully save analyze files") + return nil +} + +func (at *analyzeTask) OnEnqueue(ctx context.Context) error { + at.queueDur = 0 + at.tr.RecordSpan() + at.startTime = time.Now().UnixMicro() + log.Ctx(ctx).Info("IndexNode analyzeTask enqueued", zap.String("clusterID", at.req.GetClusterID()), + zap.Int64("taskID", at.req.GetTaskID())) + return nil +} + +func (at *analyzeTask) SetState(state indexpb.JobState, failReason string) { + at.node.storeAnalyzeTaskState(at.req.GetClusterID(), at.req.GetTaskID(), state, failReason) +} + +func (at *analyzeTask) GetState() 
indexpb.JobState { + return at.node.loadAnalyzeTaskState(at.req.GetClusterID(), at.req.GetTaskID()) +} + +func (at *analyzeTask) Reset() { + at.ident = "" + at.ctx = nil + at.cancel = nil + at.req = nil + at.tr = nil + at.queueDur = 0 + at.node = nil + at.startTime = 0 + at.endTime = 0 +} diff --git a/internal/indexnode/task_index.go b/internal/indexnode/task_index.go new file mode 100644 index 0000000000..8c2791ac05 --- /dev/null +++ b/internal/indexnode/task_index.go @@ -0,0 +1,570 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package indexnode + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/indexcgopb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/indexcgowrapper" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" + "github.com/milvus-io/milvus/pkg/util/indexparams" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metautil" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/timerecord" +) + +type indexBuildTaskV2 struct { + *indexBuildTask +} + +func newIndexBuildTaskV2(ctx context.Context, + cancel context.CancelFunc, + req *indexpb.CreateJobRequest, + node *IndexNode, +) *indexBuildTaskV2 { + t := &indexBuildTaskV2{ + indexBuildTask: &indexBuildTask{ + ident: fmt.Sprintf("%s/%d", req.GetClusterID(), req.GetBuildID()), + cancel: cancel, + ctx: ctx, + req: req, + tr: timerecord.NewTimeRecorder(fmt.Sprintf("IndexBuildID: %d, ClusterID: %s", req.GetBuildID(), req.GetClusterID())), + node: node, + }, + } + + t.parseParams() + return t +} + +func (it *indexBuildTaskV2) parseParams() { + // fill field for requests before v2.5.0 + if it.req.GetField() == nil || it.req.GetField().GetDataType() == schemapb.DataType_None { + it.req.Field = &schemapb.FieldSchema{ + FieldID: it.req.GetFieldID(), + Name: it.req.GetFieldName(), + DataType: it.req.GetFieldType(), + } + } +} + +func (it *indexBuildTaskV2) Execute(ctx context.Context) error { + log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()), + zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()), + zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) + + indexType := it.newIndexParams[common.IndexTypeKey] + if indexType == 
indexparamcheck.IndexDISKANN { + // check index node support disk index + if !Params.IndexNodeCfg.EnableDisk.GetAsBool() { + log.Warn("IndexNode don't support build disk index", + zap.String("index type", it.newIndexParams[common.IndexTypeKey]), + zap.Bool("enable disk", Params.IndexNodeCfg.EnableDisk.GetAsBool())) + return merr.WrapErrIndexNotSupported("disk index") + } + + // check load size and size of field data + localUsedSize, err := indexcgowrapper.GetLocalUsedSize(paramtable.Get().LocalStorageCfg.Path.GetValue()) + if err != nil { + log.Warn("IndexNode get local used size failed") + return err + } + fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) + if err != nil { + log.Warn("IndexNode get local used size failed") + return err + } + usedLocalSizeWhenBuild := int64(float64(fieldDataSize)*diskUsageRatio) + localUsedSize + maxUsedLocalSize := int64(Params.IndexNodeCfg.DiskCapacityLimit.GetAsFloat() * Params.IndexNodeCfg.MaxDiskUsagePercentage.GetAsFloat()) + + if usedLocalSizeWhenBuild > maxUsedLocalSize { + log.Warn("IndexNode don't has enough disk size to build disk ann index", + zap.Int64("usedLocalSizeWhenBuild", usedLocalSizeWhenBuild), + zap.Int64("maxUsedLocalSize", maxUsedLocalSize)) + return merr.WrapErrServiceDiskLimitExceeded(float32(usedLocalSizeWhenBuild), float32(maxUsedLocalSize)) + } + + err = indexparams.SetDiskIndexBuildParams(it.newIndexParams, int64(fieldDataSize)) + if err != nil { + log.Warn("failed to fill disk index params", zap.Error(err)) + return err + } + } + + storageConfig := &indexcgopb.StorageConfig{ + Address: it.req.GetStorageConfig().GetAddress(), + AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), + SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(), + UseSSL: it.req.GetStorageConfig().GetUseSSL(), + BucketName: it.req.GetStorageConfig().GetBucketName(), + RootPath: it.req.GetStorageConfig().GetRootPath(), + UseIAM: it.req.GetStorageConfig().GetUseIAM(), + IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(), + StorageType: it.req.GetStorageConfig().GetStorageType(), + UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(), + Region: it.req.GetStorageConfig().GetRegion(), + CloudProvider: it.req.GetStorageConfig().GetCloudProvider(), + RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(), + SslCACert: it.req.GetStorageConfig().GetSslCACert(), + } + + optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields())) + for _, optField := range it.req.GetOptionalScalarFields() { + optFields = append(optFields, &indexcgopb.OptionalFieldInfo{ + FieldID: optField.GetFieldID(), + FieldName: optField.GetFieldName(), + FieldType: optField.GetFieldType(), + DataPaths: optField.GetDataPaths(), + }) + } + + buildIndexParams := &indexcgopb.BuildIndexInfo{ + ClusterID: it.req.GetClusterID(), + BuildID: it.req.GetBuildID(), + CollectionID: it.req.GetCollectionID(), + PartitionID: it.req.GetPartitionID(), + SegmentID: it.req.GetSegmentID(), + IndexVersion: it.req.GetIndexVersion(), + CurrentIndexVersion: it.req.GetCurrentIndexVersion(), + NumRows: it.req.GetNumRows(), + Dim: it.req.GetDim(), + IndexFilePrefix: it.req.GetIndexFilePrefix(), + InsertFiles: it.req.GetDataPaths(), + FieldSchema: it.req.GetField(), + StorageConfig: storageConfig, + IndexParams: mapToKVPairs(it.newIndexParams), + TypeParams: mapToKVPairs(it.newTypeParams), + StorePath: it.req.GetStorePath(), + StoreVersion: it.req.GetStoreVersion(), + IndexStorePath: 
it.req.GetIndexStorePath(), + OptFields: optFields, + } + + var err error + it.index, err = indexcgowrapper.CreateIndexV2(ctx, buildIndexParams) + if err != nil { + if it.index != nil && it.index.CleanLocalData() != nil { + log.Warn("failed to clean cached data on disk after build index failed") + } + log.Warn("failed to build index", zap.Error(err)) + return err + } + + buildIndexLatency := it.tr.RecordSpan() + metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(buildIndexLatency.Milliseconds())) + + log.Info("Successfully build index") + return nil +} + +func (it *indexBuildTaskV2) PostExecute(ctx context.Context) error { + log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()), + zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()), + zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) + + gcIndex := func() { + if err := it.index.Delete(); err != nil { + log.Warn("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err)) + } + } + version, err := it.index.UpLoadV2() + if err != nil { + log.Warn("failed to upload index", zap.Error(err)) + gcIndex() + return err + } + + encodeIndexFileDur := it.tr.Record("index serialize and upload done") + metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(encodeIndexFileDur.Seconds()) + + // early release index for gc, and we can ensure that Delete is idempotent. + gcIndex() + + // use serialized size before encoding + var serializedSize uint64 + saveFileKeys := make([]string, 0) + + it.node.storeIndexFilesAndStatisticV2(it.req.GetClusterID(), it.req.GetBuildID(), saveFileKeys, serializedSize, it.req.GetCurrentIndexVersion(), version) + log.Debug("save index files done", zap.Strings("IndexFiles", saveFileKeys)) + saveIndexFileDur := it.tr.RecordSpan() + metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(saveIndexFileDur.Seconds()) + it.tr.Elapse("index building all done") + log.Info("Successfully save index files") + return nil +} + +// IndexBuildTask is used to record the information of the index tasks. 
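Looking back at the new analyzeTask earlier in this patch: its Execute method expands the per-segment stats carried in the AnalyzeRequest into concrete insert-log object keys before handing an AnalyzeInfo to the CGO analyzer. The self-contained sketch below mirrors that loop; buildSegmentInsertPaths is an invented name used only for illustration, and the key layout quoted in the comment is the conventional Milvus insert-log layout rather than something this patch guarantees.

package indexnode

import (
    "github.com/milvus-io/milvus/internal/proto/indexpb"
    "github.com/milvus-io/milvus/pkg/util/metautil"
)

// buildSegmentInsertPaths mirrors the expansion done in analyzeTask.Execute:
// every binlog ID of every segment becomes a full insert-log path.
func buildSegmentInsertPaths(req *indexpb.AnalyzeRequest) map[int64][]string {
    out := make(map[int64][]string, len(req.GetSegmentStats()))
    for segID, stats := range req.GetSegmentStats() {
        paths := make([]string, 0, len(stats.GetLogIDs()))
        for _, logID := range stats.GetLogIDs() {
            // Keys typically look like:
            // <rootPath>/insert_log/<collectionID>/<partitionID>/<segmentID>/<fieldID>/<logID>
            paths = append(paths, metautil.BuildInsertLogPath(
                req.GetStorageConfig().GetRootPath(),
                req.GetCollectionID(), req.GetPartitionID(), segID,
                req.GetFieldID(), logID))
        }
        out[segID] = paths
    }
    return out
}

The same Execute also derives TrainSize as hardware.GetMemoryCount() scaled by MaxTrainSizeRatio, so the k-means training sample is bounded by host memory rather than by total segment size.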
+type indexBuildTask struct { + ident string + cancel context.CancelFunc + ctx context.Context + + cm storage.ChunkManager + index indexcgowrapper.CodecIndex + req *indexpb.CreateJobRequest + newTypeParams map[string]string + newIndexParams map[string]string + tr *timerecord.TimeRecorder + queueDur time.Duration + node *IndexNode +} + +func newIndexBuildTask(ctx context.Context, + cancel context.CancelFunc, + req *indexpb.CreateJobRequest, + cm storage.ChunkManager, + node *IndexNode, +) *indexBuildTask { + t := &indexBuildTask{ + ident: fmt.Sprintf("%s/%d", req.GetClusterID(), req.GetBuildID()), + cancel: cancel, + ctx: ctx, + cm: cm, + req: req, + tr: timerecord.NewTimeRecorder(fmt.Sprintf("IndexBuildID: %d, ClusterID: %s", req.GetBuildID(), req.GetClusterID())), + node: node, + } + + t.parseParams() + return t +} + +func (it *indexBuildTask) parseParams() { + // fill field for requests before v2.5.0 + if it.req.GetField() == nil || it.req.GetField().GetDataType() == schemapb.DataType_None { + it.req.Field = &schemapb.FieldSchema{ + FieldID: it.req.GetFieldID(), + Name: it.req.GetFieldName(), + DataType: it.req.GetFieldType(), + } + } +} + +func (it *indexBuildTask) Reset() { + it.ident = "" + it.cancel = nil + it.ctx = nil + it.cm = nil + it.index = nil + it.req = nil + it.newTypeParams = nil + it.newIndexParams = nil + it.tr = nil + it.node = nil +} + +// Ctx is the context of index tasks. +func (it *indexBuildTask) Ctx() context.Context { + return it.ctx +} + +// Name is the name of task to build index. +func (it *indexBuildTask) Name() string { + return it.ident +} + +func (it *indexBuildTask) SetState(state indexpb.JobState, failReason string) { + it.node.storeIndexTaskState(it.req.GetClusterID(), it.req.GetBuildID(), commonpb.IndexState(state), failReason) +} + +func (it *indexBuildTask) GetState() indexpb.JobState { + return indexpb.JobState(it.node.loadIndexTaskState(it.req.GetClusterID(), it.req.GetBuildID())) +} + +// OnEnqueue enqueues indexing tasks. 
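A note on the SetState/GetState pair just above: they convert between indexpb.JobState and commonpb.IndexState with plain casts, which is only safe while the two enums keep matching numeric values. If that assumption ever stops holding, an explicit mapping is the more defensive shape; jobStateToIndexState below is a hypothetical helper sketched for illustration, not something this patch adds.

package indexnode

import (
    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus/internal/proto/indexpb"
)

// jobStateToIndexState maps job states to the legacy index states explicitly
// instead of relying on the two enums sharing numeric values.
func jobStateToIndexState(s indexpb.JobState) commonpb.IndexState {
    switch s {
    case indexpb.JobState_JobStateInProgress:
        return commonpb.IndexState_InProgress
    case indexpb.JobState_JobStateFinished:
        return commonpb.IndexState_Finished
    case indexpb.JobState_JobStateFailed:
        return commonpb.IndexState_Failed
    case indexpb.JobState_JobStateRetry:
        return commonpb.IndexState_Retry
    default:
        return commonpb.IndexState_IndexStateNone
    }
}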
+func (it *indexBuildTask) OnEnqueue(ctx context.Context) error { + it.queueDur = 0 + it.tr.RecordSpan() + log.Ctx(ctx).Info("IndexNode IndexBuilderTask Enqueue", zap.Int64("buildID", it.req.GetBuildID()), + zap.Int64("segmentID", it.req.GetSegmentID())) + return nil +} + +func (it *indexBuildTask) PreExecute(ctx context.Context) error { + it.queueDur = it.tr.RecordSpan() + log.Ctx(ctx).Info("Begin to prepare indexBuildTask", zap.Int64("buildID", it.req.GetBuildID()), + zap.Int64("Collection", it.req.GetCollectionID()), zap.Int64("SegmentID", it.req.GetSegmentID())) + + typeParams := make(map[string]string) + indexParams := make(map[string]string) + + if len(it.req.DataPaths) == 0 { + for _, id := range it.req.GetDataIds() { + path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), it.req.GetField().GetFieldID(), id) + it.req.DataPaths = append(it.req.DataPaths, path) + } + } + + if it.req.OptionalScalarFields != nil { + for _, optFields := range it.req.GetOptionalScalarFields() { + if len(optFields.DataPaths) == 0 { + for _, id := range optFields.DataIds { + path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), optFields.FieldID, id) + optFields.DataPaths = append(optFields.DataPaths, path) + } + } + } + } + + // type params can be removed + for _, kvPair := range it.req.GetTypeParams() { + key, value := kvPair.GetKey(), kvPair.GetValue() + typeParams[key] = value + indexParams[key] = value + } + + for _, kvPair := range it.req.GetIndexParams() { + key, value := kvPair.GetKey(), kvPair.GetValue() + // knowhere would report error if encountered the unknown key, + // so skip this + if key == common.MmapEnabledKey { + continue + } + indexParams[key] = value + } + it.newTypeParams = typeParams + it.newIndexParams = indexParams + + if it.req.GetDim() == 0 { + // fill dim for requests before v2.4.0 + if dimStr, ok := typeParams[common.DimKey]; ok { + var err error + it.req.Dim, err = strconv.ParseInt(dimStr, 10, 64) + if err != nil { + log.Ctx(ctx).Error("parse dimesion failed", zap.Error(err)) + // ignore error + } + } + } + + if it.req.GetCollectionID() == 0 { + err := it.parseFieldMetaFromBinlog(ctx) + if err != nil { + log.Ctx(ctx).Warn("parse field meta from binlog failed", zap.Error(err)) + return err + } + } + + log.Ctx(ctx).Info("Successfully prepare indexBuildTask", zap.Int64("buildID", it.req.GetBuildID()), + zap.Int64("collectionID", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID())) + return nil +} + +func (it *indexBuildTask) Execute(ctx context.Context) error { + log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()), + zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()), + zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) + + indexType := it.newIndexParams[common.IndexTypeKey] + if indexType == indexparamcheck.IndexDISKANN { + // check index node support disk index + if !Params.IndexNodeCfg.EnableDisk.GetAsBool() { + log.Warn("IndexNode don't support build disk index", + zap.String("index type", it.newIndexParams[common.IndexTypeKey]), + zap.Bool("enable disk", Params.IndexNodeCfg.EnableDisk.GetAsBool())) + return errors.New("index node don't support build disk index") + } + + // check load size and size of field data + localUsedSize, err := 
indexcgowrapper.GetLocalUsedSize(paramtable.Get().LocalStorageCfg.Path.GetValue()) + if err != nil { + log.Warn("IndexNode get local used size failed") + return err + } + fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) + if err != nil { + log.Warn("IndexNode get local used size failed") + return err + } + usedLocalSizeWhenBuild := int64(float64(fieldDataSize)*diskUsageRatio) + localUsedSize + maxUsedLocalSize := int64(Params.IndexNodeCfg.DiskCapacityLimit.GetAsFloat() * Params.IndexNodeCfg.MaxDiskUsagePercentage.GetAsFloat()) + + if usedLocalSizeWhenBuild > maxUsedLocalSize { + log.Warn("IndexNode don't has enough disk size to build disk ann index", + zap.Int64("usedLocalSizeWhenBuild", usedLocalSizeWhenBuild), + zap.Int64("maxUsedLocalSize", maxUsedLocalSize)) + return errors.New("index node don't has enough disk size to build disk ann index") + } + + err = indexparams.SetDiskIndexBuildParams(it.newIndexParams, int64(fieldDataSize)) + if err != nil { + log.Warn("failed to fill disk index params", zap.Error(err)) + return err + } + } + + storageConfig := &indexcgopb.StorageConfig{ + Address: it.req.GetStorageConfig().GetAddress(), + AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), + SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(), + UseSSL: it.req.GetStorageConfig().GetUseSSL(), + BucketName: it.req.GetStorageConfig().GetBucketName(), + RootPath: it.req.GetStorageConfig().GetRootPath(), + UseIAM: it.req.GetStorageConfig().GetUseIAM(), + IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(), + StorageType: it.req.GetStorageConfig().GetStorageType(), + UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(), + Region: it.req.GetStorageConfig().GetRegion(), + CloudProvider: it.req.GetStorageConfig().GetCloudProvider(), + RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(), + SslCACert: it.req.GetStorageConfig().GetSslCACert(), + } + + optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields())) + for _, optField := range it.req.GetOptionalScalarFields() { + optFields = append(optFields, &indexcgopb.OptionalFieldInfo{ + FieldID: optField.GetFieldID(), + FieldName: optField.GetFieldName(), + FieldType: optField.GetFieldType(), + DataPaths: optField.GetDataPaths(), + }) + } + + buildIndexParams := &indexcgopb.BuildIndexInfo{ + ClusterID: it.req.GetClusterID(), + BuildID: it.req.GetBuildID(), + CollectionID: it.req.GetCollectionID(), + PartitionID: it.req.GetPartitionID(), + SegmentID: it.req.GetSegmentID(), + IndexVersion: it.req.GetIndexVersion(), + CurrentIndexVersion: it.req.GetCurrentIndexVersion(), + NumRows: it.req.GetNumRows(), + Dim: it.req.GetDim(), + IndexFilePrefix: it.req.GetIndexFilePrefix(), + InsertFiles: it.req.GetDataPaths(), + FieldSchema: it.req.GetField(), + StorageConfig: storageConfig, + IndexParams: mapToKVPairs(it.newIndexParams), + TypeParams: mapToKVPairs(it.newTypeParams), + StorePath: it.req.GetStorePath(), + StoreVersion: it.req.GetStoreVersion(), + IndexStorePath: it.req.GetIndexStorePath(), + OptFields: optFields, + } + + log.Info("debug create index", zap.Any("buildIndexParams", buildIndexParams)) + var err error + it.index, err = indexcgowrapper.CreateIndex(ctx, buildIndexParams) + if err != nil { + if it.index != nil && it.index.CleanLocalData() != nil { + log.Warn("failed to clean cached data on disk after build index failed") + } + log.Warn("failed to build index", zap.Error(err)) + return err + } + + 
buildIndexLatency := it.tr.RecordSpan() + metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(buildIndexLatency.Seconds()) + + log.Info("Successfully build index") + return nil +} + +func (it *indexBuildTask) PostExecute(ctx context.Context) error { + log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()), + zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()), + zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) + + gcIndex := func() { + if err := it.index.Delete(); err != nil { + log.Warn("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err)) + } + } + indexFilePath2Size, err := it.index.UpLoad() + if err != nil { + log.Warn("failed to upload index", zap.Error(err)) + gcIndex() + return err + } + encodeIndexFileDur := it.tr.Record("index serialize and upload done") + metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(encodeIndexFileDur.Seconds()) + + // early release index for gc, and we can ensure that Delete is idempotent. + gcIndex() + + // use serialized size before encoding + var serializedSize uint64 + saveFileKeys := make([]string, 0) + for filePath, fileSize := range indexFilePath2Size { + serializedSize += uint64(fileSize) + parts := strings.Split(filePath, "/") + fileKey := parts[len(parts)-1] + saveFileKeys = append(saveFileKeys, fileKey) + } + + it.node.storeIndexFilesAndStatistic(it.req.GetClusterID(), it.req.GetBuildID(), saveFileKeys, serializedSize, it.req.GetCurrentIndexVersion()) + log.Debug("save index files done", zap.Strings("IndexFiles", saveFileKeys)) + saveIndexFileDur := it.tr.RecordSpan() + metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(saveIndexFileDur.Seconds()) + it.tr.Elapse("index building all done") + log.Info("Successfully save index files") + return nil +} + +func (it *indexBuildTask) parseFieldMetaFromBinlog(ctx context.Context) error { + // fill collectionID, partitionID... 
for requests before v2.4.0 + toLoadDataPaths := it.req.GetDataPaths() + if len(toLoadDataPaths) == 0 { + return merr.WrapErrParameterInvalidMsg("data insert path must be not empty") + } + data, err := it.cm.Read(ctx, toLoadDataPaths[0]) + if err != nil { + if errors.Is(err, merr.ErrIoKeyNotFound) { + return err + } + return err + } + + var insertCodec storage.InsertCodec + collectionID, partitionID, segmentID, insertData, err := insertCodec.DeserializeAll([]*Blob{{Key: toLoadDataPaths[0], Value: data}}) + if err != nil { + return err + } + if len(insertData.Data) != 1 { + return merr.WrapErrParameterInvalidMsg("we expect only one field in deserialized insert data") + } + + it.req.CollectionID = collectionID + it.req.PartitionID = partitionID + it.req.SegmentID = segmentID + if it.req.GetField().GetFieldID() == 0 { + for fID, value := range insertData.Data { + it.req.Field.DataType = value.GetDataType() + it.req.Field.FieldID = fID + break + } + } + it.req.CurrentIndexVersion = getCurrentIndexVersion(it.req.GetCurrentIndexVersion()) + + return nil +} diff --git a/internal/indexnode/task_scheduler.go b/internal/indexnode/task_scheduler.go index 539a887480..3f5c986149 100644 --- a/internal/indexnode/task_scheduler.go +++ b/internal/indexnode/task_scheduler.go @@ -26,7 +26,7 @@ import ( "github.com/cockroachdb/errors" "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/merr" @@ -147,7 +147,7 @@ func (queue *IndexTaskQueue) GetTaskNum() (int, int) { atNum := 0 // remove the finished task for _, task := range queue.activeTasks { - if task.GetState() != commonpb.IndexState_Finished && task.GetState() != commonpb.IndexState_Failed { + if task.GetState() != indexpb.JobState_JobStateFinished && task.GetState() != indexpb.JobState_JobStateFailed { atNum++ } } @@ -168,7 +168,7 @@ func NewIndexBuildTaskQueue(sched *TaskScheduler) *IndexTaskQueue { // TaskScheduler is a scheduler of indexing tasks. 
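parseFieldMetaFromBinlog above recovers the collection, partition, and segment IDs for pre-v2.4.0 requests by deserializing the first insert binlog. A compact sketch of that recovery path follows; metaFromFirstBinlog is an invented wrapper name used only for illustration.

package indexnode

import (
    "context"

    "github.com/milvus-io/milvus/internal/storage"
)

// metaFromFirstBinlog reads one insert binlog and lets the codec report which
// collection/partition/segment it belongs to, as parseFieldMetaFromBinlog does.
func metaFromFirstBinlog(ctx context.Context, cm storage.ChunkManager, path string) (collectionID, partitionID, segmentID int64, err error) {
    value, err := cm.Read(ctx, path)
    if err != nil {
        return 0, 0, 0, err
    }
    var codec storage.InsertCodec
    collectionID, partitionID, segmentID, _, err = codec.DeserializeAll(
        []*storage.Blob{{Key: path, Value: value}})
    return collectionID, partitionID, segmentID, err
}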
type TaskScheduler struct { - IndexBuildQueue TaskQueue + TaskQueue TaskQueue buildParallel int wg sync.WaitGroup @@ -184,7 +184,7 @@ func NewTaskScheduler(ctx context.Context) *TaskScheduler { cancel: cancel, buildParallel: Params.IndexNodeCfg.BuildParallel.GetAsInt(), } - s.IndexBuildQueue = NewIndexBuildTaskQueue(s) + s.TaskQueue = NewIndexBuildTaskQueue(s) return s } @@ -192,7 +192,7 @@ func NewTaskScheduler(ctx context.Context) *TaskScheduler { func (sched *TaskScheduler) scheduleIndexBuildTask() []task { ret := make([]task, 0) for i := 0; i < sched.buildParallel; i++ { - t := sched.IndexBuildQueue.PopUnissuedTask() + t := sched.TaskQueue.PopUnissuedTask() if t == nil { return ret } @@ -201,14 +201,16 @@ func (sched *TaskScheduler) scheduleIndexBuildTask() []task { return ret } -func getStateFromError(err error) commonpb.IndexState { +func getStateFromError(err error) indexpb.JobState { if errors.Is(err, errCancel) { - return commonpb.IndexState_Retry + return indexpb.JobState_JobStateRetry } else if errors.Is(err, merr.ErrIoKeyNotFound) || errors.Is(err, merr.ErrSegcoreUnsupported) { // NoSuchKey or unsupported error - return commonpb.IndexState_Failed + return indexpb.JobState_JobStateFailed + } else if errors.Is(err, merr.ErrSegcorePretendFinished) { + return indexpb.JobState_JobStateFinished } - return commonpb.IndexState_Retry + return indexpb.JobState_JobStateRetry } func (sched *TaskScheduler) processTask(t task, q TaskQueue) { @@ -225,10 +227,10 @@ func (sched *TaskScheduler) processTask(t task, q TaskQueue) { t.Reset() debug.FreeOSMemory() }() - sched.IndexBuildQueue.AddActiveTask(t) - defer sched.IndexBuildQueue.PopActiveTask(t.Name()) + sched.TaskQueue.AddActiveTask(t) + defer sched.TaskQueue.PopActiveTask(t.Name()) log.Ctx(t.Ctx()).Debug("process task", zap.String("task", t.Name())) - pipelines := []func(context.Context) error{t.Prepare, t.BuildIndex, t.SaveIndexFiles} + pipelines := []func(context.Context) error{t.PreExecute, t.Execute, t.PostExecute} for _, fn := range pipelines { if err := wrap(fn); err != nil { log.Ctx(t.Ctx()).Warn("process task failed", zap.Error(err)) @@ -236,7 +238,7 @@ func (sched *TaskScheduler) processTask(t task, q TaskQueue) { return } } - t.SetState(commonpb.IndexState_Finished, "") + t.SetState(indexpb.JobState_JobStateFinished, "") if indexBuildTask, ok := t.(*indexBuildTask); ok { metrics.IndexNodeBuildIndexLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(indexBuildTask.tr.ElapseSpan().Seconds()) metrics.IndexNodeIndexTaskLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(indexBuildTask.queueDur.Milliseconds())) @@ -250,14 +252,14 @@ func (sched *TaskScheduler) indexBuildLoop() { select { case <-sched.ctx.Done(): return - case <-sched.IndexBuildQueue.utChan(): + case <-sched.TaskQueue.utChan(): tasks := sched.scheduleIndexBuildTask() var wg sync.WaitGroup for _, t := range tasks { wg.Add(1) go func(group *sync.WaitGroup, t task) { defer group.Done() - sched.processTask(t, sched.IndexBuildQueue) + sched.processTask(t, sched.TaskQueue) }(&wg, t) } wg.Wait() diff --git a/internal/indexnode/task_scheduler_test.go b/internal/indexnode/task_scheduler_test.go index 2393fd2b7e..36e5b04db3 100644 --- a/internal/indexnode/task_scheduler_test.go +++ b/internal/indexnode/task_scheduler_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" 
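The scheduler change above also widens getStateFromError: cancellation and unknown errors still map to a retry, NoSuchKey and unsupported-segcore errors fail the job, and the new ErrSegcorePretendFinished case short-circuits straight to a finished state. A small illustrative helper makes the mapping concrete; demoGetStateFromError is hypothetical and not part of the patch.

package indexnode

import (
    "fmt"

    "github.com/milvus-io/milvus/pkg/util/merr"
)

// demoGetStateFromError prints the job state chosen for a few representative errors.
func demoGetStateFromError() {
    fmt.Println(getStateFromError(errCancel))                         // JobStateRetry
    fmt.Println(getStateFromError(merr.ErrIoKeyNotFound))             // JobStateFailed
    fmt.Println(getStateFromError(merr.ErrSegcoreUnsupported))        // JobStateFailed
    fmt.Println(getStateFromError(merr.ErrSegcorePretendFinished))    // JobStateFinished
    fmt.Println(getStateFromError(fmt.Errorf("transient io error"))) // JobStateRetry
}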
"github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -72,8 +72,8 @@ type fakeTask struct { ctx context.Context state fakeTaskState reterr map[fakeTaskState]error - retstate commonpb.IndexState - expectedState commonpb.IndexState + retstate indexpb.JobState + expectedState indexpb.JobState failReason string } @@ -94,7 +94,7 @@ func (t *fakeTask) OnEnqueue(ctx context.Context) error { return t.reterr[t.state] } -func (t *fakeTask) Prepare(ctx context.Context) error { +func (t *fakeTask) PreExecute(ctx context.Context) error { t.state = fakeTaskPrepared t.ctx.(*stagectx).setState(t.state) return t.reterr[t.state] @@ -106,13 +106,13 @@ func (t *fakeTask) LoadData(ctx context.Context) error { return t.reterr[t.state] } -func (t *fakeTask) BuildIndex(ctx context.Context) error { +func (t *fakeTask) Execute(ctx context.Context) error { t.state = fakeTaskBuiltIndex t.ctx.(*stagectx).setState(t.state) return t.reterr[t.state] } -func (t *fakeTask) SaveIndexFiles(ctx context.Context) error { +func (t *fakeTask) PostExecute(ctx context.Context) error { t.state = fakeTaskSavedIndexes t.ctx.(*stagectx).setState(t.state) return t.reterr[t.state] @@ -122,12 +122,12 @@ func (t *fakeTask) Reset() { _taskwg.Done() } -func (t *fakeTask) SetState(state commonpb.IndexState, failReason string) { +func (t *fakeTask) SetState(state indexpb.JobState, failReason string) { t.retstate = state t.failReason = failReason } -func (t *fakeTask) GetState() commonpb.IndexState { +func (t *fakeTask) GetState() indexpb.JobState { return t.retstate } @@ -136,7 +136,7 @@ var ( id = 0 ) -func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expectedState commonpb.IndexState) task { +func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expectedState indexpb.JobState) task { idLock.Lock() newID := id id++ @@ -151,7 +151,7 @@ func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expect ch: make(chan struct{}), }, state: fakeTaskInited, - retstate: commonpb.IndexState_IndexStateNone, + retstate: indexpb.JobState_JobStateNone, expectedState: expectedState, } } @@ -165,14 +165,14 @@ func TestIndexTaskScheduler(t *testing.T) { tasks := make([]task, 0) tasks = append(tasks, - newTask(fakeTaskEnqueued, nil, commonpb.IndexState_Retry), - newTask(fakeTaskPrepared, nil, commonpb.IndexState_Retry), - newTask(fakeTaskBuiltIndex, nil, commonpb.IndexState_Retry), - newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished), - newTask(fakeTaskSavedIndexes, map[fakeTaskState]error{fakeTaskSavedIndexes: fmt.Errorf("auth failed")}, commonpb.IndexState_Retry)) + newTask(fakeTaskEnqueued, nil, indexpb.JobState_JobStateRetry), + newTask(fakeTaskPrepared, nil, indexpb.JobState_JobStateRetry), + newTask(fakeTaskBuiltIndex, nil, indexpb.JobState_JobStateRetry), + newTask(fakeTaskSavedIndexes, nil, indexpb.JobState_JobStateFinished), + newTask(fakeTaskSavedIndexes, map[fakeTaskState]error{fakeTaskSavedIndexes: fmt.Errorf("auth failed")}, indexpb.JobState_JobStateRetry)) for _, task := range tasks { - assert.Nil(t, scheduler.IndexBuildQueue.Enqueue(task)) + assert.Nil(t, scheduler.TaskQueue.Enqueue(task)) } _taskwg.Wait() scheduler.Close() @@ -189,11 +189,11 @@ func TestIndexTaskScheduler(t *testing.T) { scheduler = NewTaskScheduler(context.TODO()) tasks = make([]task, 0, 1024) for i := 0; i < 1024; i++ { - tasks = append(tasks, newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished)) - assert.Nil(t, scheduler.IndexBuildQueue.Enqueue(tasks[len(tasks)-1])) + tasks = append(tasks, 
newTask(fakeTaskSavedIndexes, nil, indexpb.JobState_JobStateFinished)) + assert.Nil(t, scheduler.TaskQueue.Enqueue(tasks[len(tasks)-1])) } - failTask := newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished) - err := scheduler.IndexBuildQueue.Enqueue(failTask) + failTask := newTask(fakeTaskSavedIndexes, nil, indexpb.JobState_JobStateFinished) + err := scheduler.TaskQueue.Enqueue(failTask) assert.Error(t, err) failTask.Reset() @@ -202,6 +202,6 @@ func TestIndexTaskScheduler(t *testing.T) { scheduler.Close() scheduler.wg.Wait() for _, task := range tasks { - assert.Equal(t, task.GetState(), commonpb.IndexState_Finished) + assert.Equal(t, task.GetState(), indexpb.JobState_JobStateFinished) } } diff --git a/internal/indexnode/task_test.go b/internal/indexnode/task_test.go index 530dcdadac..28de64275f 100644 --- a/internal/indexnode/task_test.go +++ b/internal/indexnode/task_test.go @@ -30,14 +30,115 @@ import ( milvus_storage "github.com/milvus-io/milvus-storage/go/storage" "github.com/milvus-io/milvus-storage/go/storage/options" "github.com/milvus-io/milvus-storage/go/storage/schema" + "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/timerecord" ) +type IndexBuildTaskSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + collectionID int64 + partitionID int64 + segmentID int64 + dataPath string + + numRows int + dim int +} + +func (suite *IndexBuildTaskSuite) SetupSuite() { + paramtable.Init() + suite.collectionID = 1000 + suite.partitionID = 1001 + suite.segmentID = 1002 + suite.dataPath = "/tmp/milvus/data/1000/1001/1002/3/1" + suite.numRows = 100 + suite.dim = 128 +} + +func (suite *IndexBuildTaskSuite) SetupTest() { + suite.schema = &schemapb.CollectionSchema{ + Name: "test", + Description: "test", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + {FieldID: common.RowIDField, Name: common.RowIDFieldName, DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: common.TimeStampField, Name: common.TimeStampFieldName, DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "ts", DataType: schemapb.DataType_Int64}, + {FieldID: 102, Name: "vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + }, + } +} + +func (suite *IndexBuildTaskSuite) serializeData() ([]*storage.Blob, error) { + insertCodec := storage.NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ + Schema: suite.schema, + }) + return insertCodec.Serialize(suite.partitionID, suite.segmentID, &storage.InsertData{ + Data: map[storage.FieldID]storage.FieldData{ + 0: &storage.Int64FieldData{Data: generateLongs(suite.numRows)}, + 1: &storage.Int64FieldData{Data: generateLongs(suite.numRows)}, + 100: &storage.Int64FieldData{Data: generateLongs(suite.numRows)}, + 101: &storage.Int64FieldData{Data: generateLongs(suite.numRows)}, + 102: &storage.FloatVectorFieldData{Data: generateFloats(suite.numRows * suite.dim), Dim: suite.dim}, + }, + Infos: []storage.BlobInfo{{Length: suite.numRows}}, + }) +} + +func (suite 
*IndexBuildTaskSuite) TestBuildMemoryIndex() { + ctx, cancel := context.WithCancel(context.Background()) + req := &indexpb.CreateJobRequest{ + BuildID: 1, + IndexVersion: 1, + DataPaths: []string{suite.dataPath}, + IndexID: 0, + IndexName: "", + IndexParams: []*commonpb.KeyValuePair{{Key: common.IndexTypeKey, Value: "FLAT"}, {Key: common.MetricTypeKey, Value: metric.L2}}, + TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}, + NumRows: int64(suite.numRows), + StorageConfig: &indexpb.StorageConfig{ + RootPath: "/tmp/milvus/data", + StorageType: "local", + }, + CollectionID: 1, + PartitionID: 2, + SegmentID: 3, + FieldID: 102, + FieldName: "vec", + FieldType: schemapb.DataType_FloatVector, + } + + cm, err := NewChunkMgrFactory().NewChunkManager(ctx, req.GetStorageConfig()) + suite.NoError(err) + blobs, err := suite.serializeData() + suite.NoError(err) + err = cm.Write(ctx, suite.dataPath, blobs[0].Value) + suite.NoError(err) + + t := newIndexBuildTask(ctx, cancel, req, cm, NewIndexNode(context.Background(), dependency.NewDefaultFactory(true))) + + err = t.PreExecute(context.Background()) + suite.NoError(err) + err = t.Execute(context.Background()) + suite.NoError(err) + err = t.PostExecute(context.Background()) + suite.NoError(err) +} + +func TestIndexBuildTask(t *testing.T) { + suite.Run(t, new(IndexBuildTaskSuite)) +} + type IndexBuildTaskV2Suite struct { suite.Suite schema *schemapb.CollectionSchema @@ -125,14 +226,117 @@ func (suite *IndexBuildTaskV2Suite) TestBuildIndex() { task := newIndexBuildTaskV2(context.Background(), nil, req, NewIndexNode(context.Background(), dependency.NewDefaultFactory(true))) var err error - err = task.Prepare(context.Background()) + err = task.PreExecute(context.Background()) suite.NoError(err) - err = task.BuildIndex(context.Background()) + err = task.Execute(context.Background()) suite.NoError(err) - err = task.SaveIndexFiles(context.Background()) + err = task.PostExecute(context.Background()) suite.NoError(err) } func TestIndexBuildTaskV2Suite(t *testing.T) { suite.Run(t, new(IndexBuildTaskV2Suite)) } + +type AnalyzeTaskSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + collectionID int64 + partitionID int64 + segmentID int64 + fieldID int64 + taskID int64 +} + +func (suite *AnalyzeTaskSuite) SetupSuite() { + paramtable.Init() + suite.collectionID = 1000 + suite.partitionID = 1001 + suite.segmentID = 1002 + suite.fieldID = 102 + suite.taskID = 1004 +} + +func (suite *AnalyzeTaskSuite) SetupTest() { + suite.schema = &schemapb.CollectionSchema{ + Name: "test", + Description: "test", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + {FieldID: common.RowIDField, Name: common.RowIDFieldName, DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: common.TimeStampField, Name: common.TimeStampFieldName, DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "ts", DataType: schemapb.DataType_Int64}, + {FieldID: 102, Name: "vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "1"}}}, + }, + } +} + +func (suite *AnalyzeTaskSuite) serializeData() ([]*storage.Blob, error) { + insertCodec := storage.NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ + Schema: suite.schema, + }) + return insertCodec.Serialize(suite.partitionID, suite.segmentID, &storage.InsertData{ + Data: map[storage.FieldID]storage.FieldData{ + 0: &storage.Int64FieldData{Data: []int64{0, 1, 2}}, + 1: 
&storage.Int64FieldData{Data: []int64{1, 2, 3}}, + 100: &storage.Int64FieldData{Data: []int64{0, 1, 2}}, + 101: &storage.Int64FieldData{Data: []int64{0, 1, 2}}, + 102: &storage.FloatVectorFieldData{Data: []float32{1, 2, 3}, Dim: 1}, + }, + Infos: []storage.BlobInfo{{Length: 3}}, + }) +} + +func (suite *AnalyzeTaskSuite) TestAnalyze() { + ctx, cancel := context.WithCancel(context.Background()) + req := &indexpb.AnalyzeRequest{ + ClusterID: "test", + TaskID: 1, + CollectionID: suite.collectionID, + PartitionID: suite.partitionID, + FieldID: suite.fieldID, + FieldName: "vec", + FieldType: schemapb.DataType_FloatVector, + SegmentStats: map[int64]*indexpb.SegmentStats{ + suite.segmentID: { + ID: suite.segmentID, + NumRows: 1024, + LogIDs: []int64{1}, + }, + }, + Version: 1, + StorageConfig: &indexpb.StorageConfig{ + RootPath: "/tmp/milvus/data", + StorageType: "local", + }, + Dim: 1, + } + + cm, err := NewChunkMgrFactory().NewChunkManager(ctx, req.GetStorageConfig()) + suite.NoError(err) + blobs, err := suite.serializeData() + suite.NoError(err) + dataPath := metautil.BuildInsertLogPath(cm.RootPath(), suite.collectionID, suite.partitionID, suite.segmentID, + suite.fieldID, 1) + + err = cm.Write(ctx, dataPath, blobs[0].Value) + suite.NoError(err) + + t := &analyzeTask{ + ident: "", + cancel: cancel, + ctx: ctx, + req: req, + tr: timerecord.NewTimeRecorder("test-indexBuildTask"), + queueDur: 0, + node: NewIndexNode(context.Background(), dependency.NewDefaultFactory(true)), + } + + err = t.PreExecute(context.Background()) + suite.NoError(err) +} + +func TestAnalyzeTaskSuite(t *testing.T) { + suite.Run(t, new(AnalyzeTaskSuite)) +} diff --git a/internal/indexnode/taskinfo_ops.go b/internal/indexnode/taskinfo_ops.go index 957fab23d7..be9ea957da 100644 --- a/internal/indexnode/taskinfo_ops.go +++ b/internal/indexnode/taskinfo_ops.go @@ -7,38 +7,52 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" ) -func (i *IndexNode) loadOrStoreTask(ClusterID string, buildID UniqueID, info *taskInfo) *taskInfo { +type indexTaskInfo struct { + cancel context.CancelFunc + state commonpb.IndexState + fileKeys []string + serializedSize uint64 + failReason string + currentIndexVersion int32 + indexStoreVersion int64 + + // task statistics + statistic *indexpb.JobInfo +} + +func (i *IndexNode) loadOrStoreIndexTask(ClusterID string, buildID UniqueID, info *indexTaskInfo) *indexTaskInfo { i.stateLock.Lock() defer i.stateLock.Unlock() key := taskKey{ClusterID: ClusterID, BuildID: buildID} - oldInfo, ok := i.tasks[key] + oldInfo, ok := i.indexTasks[key] if ok { return oldInfo } - i.tasks[key] = info + i.indexTasks[key] = info return nil } -func (i *IndexNode) loadTaskState(ClusterID string, buildID UniqueID) commonpb.IndexState { +func (i *IndexNode) loadIndexTaskState(ClusterID string, buildID UniqueID) commonpb.IndexState { key := taskKey{ClusterID: ClusterID, BuildID: buildID} i.stateLock.Lock() defer i.stateLock.Unlock() - task, ok := i.tasks[key] + task, ok := i.indexTasks[key] if !ok { return commonpb.IndexState_IndexStateNone } return task.state } -func (i *IndexNode) storeTaskState(ClusterID string, buildID UniqueID, state commonpb.IndexState, failReason string) { +func (i *IndexNode) storeIndexTaskState(ClusterID string, buildID UniqueID, state commonpb.IndexState, failReason string) { key := taskKey{ClusterID: ClusterID, BuildID: buildID} 
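The indexTasks map introduced above (and the analyzeTasks map added further down in this file) is keyed by taskKey and guarded by stateLock; loadOrStoreIndexTask returns the existing entry when a (clusterID, buildID) pair is already registered, which is what lets the service reject duplicate job submissions. The sketch below shows that usage; registerIndexTask is a hypothetical wrapper, and the real call sites live in the IndexNode service code rather than in this file, so the flow is an assumption.

package indexnode

import (
    "context"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

// registerIndexTask reports whether a new (clusterID, buildID) entry was stored.
// A false result means the same job is already tracked and should not be re-run.
func registerIndexTask(node *IndexNode, clusterID string, buildID int64, cancel context.CancelFunc) bool {
    info := &indexTaskInfo{
        cancel: cancel,
        state:  commonpb.IndexState_InProgress, // assumed initial state, for illustration
    }
    return node.loadOrStoreIndexTask(clusterID, buildID, info) == nil
}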
i.stateLock.Lock() defer i.stateLock.Unlock() - if task, ok := i.tasks[key]; ok { + if task, ok := i.indexTasks[key]; ok { log.Debug("IndexNode store task state", zap.String("clusterID", ClusterID), zap.Int64("buildID", buildID), zap.String("state", state.String()), zap.String("fail reason", failReason)) task.state = state @@ -46,10 +60,10 @@ func (i *IndexNode) storeTaskState(ClusterID string, buildID UniqueID, state com } } -func (i *IndexNode) foreachTaskInfo(fn func(ClusterID string, buildID UniqueID, info *taskInfo)) { +func (i *IndexNode) foreachIndexTaskInfo(fn func(ClusterID string, buildID UniqueID, info *indexTaskInfo)) { i.stateLock.Lock() defer i.stateLock.Unlock() - for key, info := range i.tasks { + for key, info := range i.indexTasks { fn(key.ClusterID, key.BuildID, info) } } @@ -64,7 +78,7 @@ func (i *IndexNode) storeIndexFilesAndStatistic( key := taskKey{ClusterID: ClusterID, BuildID: buildID} i.stateLock.Lock() defer i.stateLock.Unlock() - if info, ok := i.tasks[key]; ok { + if info, ok := i.indexTasks[key]; ok { info.fileKeys = common.CloneStringList(fileKeys) info.serializedSize = serializedSize info.currentIndexVersion = currentIndexVersion @@ -83,7 +97,7 @@ func (i *IndexNode) storeIndexFilesAndStatisticV2( key := taskKey{ClusterID: ClusterID, BuildID: buildID} i.stateLock.Lock() defer i.stateLock.Unlock() - if info, ok := i.tasks[key]; ok { + if info, ok := i.indexTasks[key]; ok { info.fileKeys = common.CloneStringList(fileKeys) info.serializedSize = serializedSize info.currentIndexVersion = currentIndexVersion @@ -92,15 +106,15 @@ func (i *IndexNode) storeIndexFilesAndStatisticV2( } } -func (i *IndexNode) deleteTaskInfos(ctx context.Context, keys []taskKey) []*taskInfo { +func (i *IndexNode) deleteIndexTaskInfos(ctx context.Context, keys []taskKey) []*indexTaskInfo { i.stateLock.Lock() defer i.stateLock.Unlock() - deleted := make([]*taskInfo, 0, len(keys)) + deleted := make([]*indexTaskInfo, 0, len(keys)) for _, key := range keys { - info, ok := i.tasks[key] + info, ok := i.indexTasks[key] if ok { deleted = append(deleted, info) - delete(i.tasks, key) + delete(i.indexTasks, key) log.Ctx(ctx).Info("delete task infos", zap.String("cluster_id", key.ClusterID), zap.Int64("build_id", key.BuildID)) } @@ -108,13 +122,113 @@ func (i *IndexNode) deleteTaskInfos(ctx context.Context, keys []taskKey) []*task return deleted } -func (i *IndexNode) deleteAllTasks() []*taskInfo { +func (i *IndexNode) deleteAllIndexTasks() []*indexTaskInfo { i.stateLock.Lock() - deletedTasks := i.tasks - i.tasks = make(map[taskKey]*taskInfo) + deletedTasks := i.indexTasks + i.indexTasks = make(map[taskKey]*indexTaskInfo) i.stateLock.Unlock() - deleted := make([]*taskInfo, 0, len(deletedTasks)) + deleted := make([]*indexTaskInfo, 0, len(deletedTasks)) + for _, info := range deletedTasks { + deleted = append(deleted, info) + } + return deleted +} + +type analyzeTaskInfo struct { + cancel context.CancelFunc + state indexpb.JobState + failReason string + centroidsFile string +} + +func (i *IndexNode) loadOrStoreAnalyzeTask(clusterID string, taskID UniqueID, info *analyzeTaskInfo) *analyzeTaskInfo { + i.stateLock.Lock() + defer i.stateLock.Unlock() + key := taskKey{ClusterID: clusterID, BuildID: taskID} + oldInfo, ok := i.analyzeTasks[key] + if ok { + return oldInfo + } + i.analyzeTasks[key] = info + return nil +} + +func (i *IndexNode) loadAnalyzeTaskState(clusterID string, taskID UniqueID) indexpb.JobState { + key := taskKey{ClusterID: clusterID, BuildID: taskID} + i.stateLock.Lock() + defer 
i.stateLock.Unlock() + task, ok := i.analyzeTasks[key] + if !ok { + return indexpb.JobState_JobStateNone + } + return task.state +} + +func (i *IndexNode) storeAnalyzeTaskState(clusterID string, taskID UniqueID, state indexpb.JobState, failReason string) { + key := taskKey{ClusterID: clusterID, BuildID: taskID} + i.stateLock.Lock() + defer i.stateLock.Unlock() + if task, ok := i.analyzeTasks[key]; ok { + log.Info("IndexNode store analyze task state", zap.String("clusterID", clusterID), zap.Int64("taskID", taskID), + zap.String("state", state.String()), zap.String("fail reason", failReason)) + task.state = state + task.failReason = failReason + } +} + +func (i *IndexNode) foreachAnalyzeTaskInfo(fn func(clusterID string, taskID UniqueID, info *analyzeTaskInfo)) { + i.stateLock.Lock() + defer i.stateLock.Unlock() + for key, info := range i.analyzeTasks { + fn(key.ClusterID, key.BuildID, info) + } +} + +func (i *IndexNode) storeAnalyzeFilesAndStatistic( + ClusterID string, + taskID UniqueID, + centroidsFile string, +) { + key := taskKey{ClusterID: ClusterID, BuildID: taskID} + i.stateLock.Lock() + defer i.stateLock.Unlock() + if info, ok := i.analyzeTasks[key]; ok { + info.centroidsFile = centroidsFile + return + } +} + +func (i *IndexNode) getAnalyzeTaskInfo(clusterID string, taskID UniqueID) *analyzeTaskInfo { + i.stateLock.Lock() + defer i.stateLock.Unlock() + + return i.analyzeTasks[taskKey{ClusterID: clusterID, BuildID: taskID}] +} + +func (i *IndexNode) deleteAnalyzeTaskInfos(ctx context.Context, keys []taskKey) []*analyzeTaskInfo { + i.stateLock.Lock() + defer i.stateLock.Unlock() + deleted := make([]*analyzeTaskInfo, 0, len(keys)) + for _, key := range keys { + info, ok := i.analyzeTasks[key] + if ok { + deleted = append(deleted, info) + delete(i.analyzeTasks, key) + log.Ctx(ctx).Info("delete analyze task infos", + zap.String("clusterID", key.ClusterID), zap.Int64("taskID", key.BuildID)) + } + } + return deleted +} + +func (i *IndexNode) deleteAllAnalyzeTasks() []*analyzeTaskInfo { + i.stateLock.Lock() + deletedTasks := i.analyzeTasks + i.analyzeTasks = make(map[taskKey]*analyzeTaskInfo) + i.stateLock.Unlock() + + deleted := make([]*analyzeTaskInfo, 0, len(deletedTasks)) for _, info := range deletedTasks { deleted = append(deleted, info) } @@ -124,11 +238,17 @@ func (i *IndexNode) deleteAllTasks() []*taskInfo { func (i *IndexNode) hasInProgressTask() bool { i.stateLock.Lock() defer i.stateLock.Unlock() - for _, info := range i.tasks { + for _, info := range i.indexTasks { if info.state == commonpb.IndexState_InProgress { return true } } + + for _, info := range i.analyzeTasks { + if info.state == indexpb.JobState_JobStateInProgress { + return true + } + } return false } @@ -151,11 +271,16 @@ func (i *IndexNode) waitTaskFinish() { } case <-timeoutCtx.Done(): log.Warn("timeout, the index node has some progress task") - for _, info := range i.tasks { + for _, info := range i.indexTasks { if info.state == commonpb.IndexState_InProgress { log.Warn("progress task", zap.Any("info", info)) } } + for _, info := range i.analyzeTasks { + if info.state == indexpb.JobState_JobStateInProgress { + log.Warn("progress task", zap.Any("info", info)) + } + } return } } diff --git a/internal/indexnode/util_test.go b/internal/indexnode/util_test.go index 6d7d98e823..53c59683ad 100644 --- a/internal/indexnode/util_test.go +++ b/internal/indexnode/util_test.go @@ -17,6 +17,7 @@ package indexnode import ( + "math/rand" "testing" "github.com/stretchr/testify/suite" @@ -39,3 +40,19 @@ func (s *utilSuite) 
Test_mapToKVPairs() { func Test_utilSuite(t *testing.T) { suite.Run(t, new(utilSuite)) } + +func generateFloats(num int) []float32 { + data := make([]float32, num) + for i := 0; i < num; i++ { + data[i] = rand.Float32() + } + return data +} + +func generateLongs(num int) []int64 { + data := make([]int64, num) + for i := 0; i < num; i++ { + data[i] = rand.Int63() + } + return data +} diff --git a/internal/metastore/catalog.go b/internal/metastore/catalog.go index 0046baca7a..26f09922d3 100644 --- a/internal/metastore/catalog.go +++ b/internal/metastore/catalog.go @@ -7,6 +7,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -152,6 +153,10 @@ type DataCoordCatalog interface { ListCompactionTask(ctx context.Context) ([]*datapb.CompactionTask, error) SaveCompactionTask(ctx context.Context, task *datapb.CompactionTask) error DropCompactionTask(ctx context.Context, task *datapb.CompactionTask) error + + ListAnalyzeTasks(ctx context.Context) ([]*indexpb.AnalyzeTask, error) + SaveAnalyzeTask(ctx context.Context, task *indexpb.AnalyzeTask) error + DropAnalyzeTask(ctx context.Context, taskID typeutil.UniqueID) error } type QueryCoordCatalog interface { diff --git a/internal/metastore/kv/datacoord/constant.go b/internal/metastore/kv/datacoord/constant.go index d53ccbddb7..8e1c5d3588 100644 --- a/internal/metastore/kv/datacoord/constant.go +++ b/internal/metastore/kv/datacoord/constant.go @@ -28,6 +28,7 @@ const ( ImportTaskPrefix = MetaPrefix + "/import-task" PreImportTaskPrefix = MetaPrefix + "/preimport-task" CompactionTaskPrefix = MetaPrefix + "/compaction-task" + AnalyzeTaskPrefix = MetaPrefix + "/analyze-task" NonRemoveFlagTomestone = "non-removed" RemoveFlagTomestone = "removed" diff --git a/internal/metastore/kv/datacoord/kv_catalog.go b/internal/metastore/kv/datacoord/kv_catalog.go index de0e575432..f82f44b3e2 100644 --- a/internal/metastore/kv/datacoord/kv_catalog.go +++ b/internal/metastore/kv/datacoord/kv_catalog.go @@ -834,3 +834,41 @@ func (kc *Catalog) DropCompactionTask(ctx context.Context, task *datapb.Compacti key := buildCompactionTaskPath(task) return kc.MetaKv.Remove(key) } + +func (kc *Catalog) ListAnalyzeTasks(ctx context.Context) ([]*indexpb.AnalyzeTask, error) { + tasks := make([]*indexpb.AnalyzeTask, 0) + + _, values, err := kc.MetaKv.LoadWithPrefix(AnalyzeTaskPrefix) + if err != nil { + return nil, err + } + for _, value := range values { + task := &indexpb.AnalyzeTask{} + err = proto.Unmarshal([]byte(value), task) + if err != nil { + return nil, err + } + tasks = append(tasks, task) + } + return tasks, nil +} + +func (kc *Catalog) SaveAnalyzeTask(ctx context.Context, task *indexpb.AnalyzeTask) error { + key := buildAnalyzeTaskKey(task.TaskID) + + value, err := proto.Marshal(task) + if err != nil { + return err + } + + err = kc.MetaKv.Save(key, string(value)) + if err != nil { + return err + } + return nil +} + +func (kc *Catalog) DropAnalyzeTask(ctx context.Context, taskID typeutil.UniqueID) error { + key := buildAnalyzeTaskKey(taskID) + return kc.MetaKv.Remove(key) +} diff --git a/internal/metastore/kv/datacoord/util.go b/internal/metastore/kv/datacoord/util.go index ffffafbd78..2d9292950e 100644 --- a/internal/metastore/kv/datacoord/util.go +++ b/internal/metastore/kv/datacoord/util.go @@ 
-325,3 +325,7 @@ func buildImportTaskKey(taskID int64) string { func buildPreImportTaskKey(taskID int64) string { return fmt.Sprintf("%s/%d", PreImportTaskPrefix, taskID) } + +func buildAnalyzeTaskKey(taskID int64) string { + return fmt.Sprintf("%s/%d", AnalyzeTaskPrefix, taskID) +} diff --git a/internal/metastore/mocks/mock_datacoord_catalog.go b/internal/metastore/mocks/mock_datacoord_catalog.go index 42ba905dba..cdcadfa5fd 100644 --- a/internal/metastore/mocks/mock_datacoord_catalog.go +++ b/internal/metastore/mocks/mock_datacoord_catalog.go @@ -6,6 +6,7 @@ import ( context "context" datapb "github.com/milvus-io/milvus/internal/proto/datapb" + indexpb "github.com/milvus-io/milvus/internal/proto/indexpb" metastore "github.com/milvus-io/milvus/internal/metastore" @@ -345,6 +346,49 @@ func (_c *DataCoordCatalog_CreateSegmentIndex_Call) RunAndReturn(run func(contex return _c } +// DropAnalyzeTask provides a mock function with given fields: ctx, taskID +func (_m *DataCoordCatalog) DropAnalyzeTask(ctx context.Context, taskID int64) error { + ret := _m.Called(ctx, taskID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, taskID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DataCoordCatalog_DropAnalyzeTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropAnalyzeTask' +type DataCoordCatalog_DropAnalyzeTask_Call struct { + *mock.Call +} + +// DropAnalyzeTask is a helper method to define mock.On call +// - ctx context.Context +// - taskID int64 +func (_e *DataCoordCatalog_Expecter) DropAnalyzeTask(ctx interface{}, taskID interface{}) *DataCoordCatalog_DropAnalyzeTask_Call { + return &DataCoordCatalog_DropAnalyzeTask_Call{Call: _e.mock.On("DropAnalyzeTask", ctx, taskID)} +} + +func (_c *DataCoordCatalog_DropAnalyzeTask_Call) Run(run func(ctx context.Context, taskID int64)) *DataCoordCatalog_DropAnalyzeTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *DataCoordCatalog_DropAnalyzeTask_Call) Return(_a0 error) *DataCoordCatalog_DropAnalyzeTask_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *DataCoordCatalog_DropAnalyzeTask_Call) RunAndReturn(run func(context.Context, int64) error) *DataCoordCatalog_DropAnalyzeTask_Call { + _c.Call.Return(run) + return _c +} + // DropChannel provides a mock function with given fields: ctx, channel func (_m *DataCoordCatalog) DropChannel(ctx context.Context, channel string) error { ret := _m.Called(ctx, channel) @@ -777,6 +821,60 @@ func (_c *DataCoordCatalog_GcConfirm_Call) RunAndReturn(run func(context.Context return _c } +// ListAnalyzeTasks provides a mock function with given fields: ctx +func (_m *DataCoordCatalog) ListAnalyzeTasks(ctx context.Context) ([]*indexpb.AnalyzeTask, error) { + ret := _m.Called(ctx) + + var r0 []*indexpb.AnalyzeTask + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*indexpb.AnalyzeTask, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*indexpb.AnalyzeTask); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*indexpb.AnalyzeTask) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DataCoordCatalog_ListAnalyzeTasks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAnalyzeTasks' +type 
DataCoordCatalog_ListAnalyzeTasks_Call struct { + *mock.Call +} + +// ListAnalyzeTasks is a helper method to define mock.On call +// - ctx context.Context +func (_e *DataCoordCatalog_Expecter) ListAnalyzeTasks(ctx interface{}) *DataCoordCatalog_ListAnalyzeTasks_Call { + return &DataCoordCatalog_ListAnalyzeTasks_Call{Call: _e.mock.On("ListAnalyzeTasks", ctx)} +} + +func (_c *DataCoordCatalog_ListAnalyzeTasks_Call) Run(run func(ctx context.Context)) *DataCoordCatalog_ListAnalyzeTasks_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *DataCoordCatalog_ListAnalyzeTasks_Call) Return(_a0 []*indexpb.AnalyzeTask, _a1 error) *DataCoordCatalog_ListAnalyzeTasks_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *DataCoordCatalog_ListAnalyzeTasks_Call) RunAndReturn(run func(context.Context) ([]*indexpb.AnalyzeTask, error)) *DataCoordCatalog_ListAnalyzeTasks_Call { + _c.Call.Return(run) + return _c +} + // ListChannelCheckpoint provides a mock function with given fields: ctx func (_m *DataCoordCatalog) ListChannelCheckpoint(ctx context.Context) (map[string]*msgpb.MsgPosition, error) { ret := _m.Called(ctx) @@ -1292,6 +1390,49 @@ func (_c *DataCoordCatalog_MarkChannelDeleted_Call) RunAndReturn(run func(contex return _c } +// SaveAnalyzeTask provides a mock function with given fields: ctx, task +func (_m *DataCoordCatalog) SaveAnalyzeTask(ctx context.Context, task *indexpb.AnalyzeTask) error { + ret := _m.Called(ctx, task) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.AnalyzeTask) error); ok { + r0 = rf(ctx, task) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DataCoordCatalog_SaveAnalyzeTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveAnalyzeTask' +type DataCoordCatalog_SaveAnalyzeTask_Call struct { + *mock.Call +} + +// SaveAnalyzeTask is a helper method to define mock.On call +// - ctx context.Context +// - task *indexpb.AnalyzeTask +func (_e *DataCoordCatalog_Expecter) SaveAnalyzeTask(ctx interface{}, task interface{}) *DataCoordCatalog_SaveAnalyzeTask_Call { + return &DataCoordCatalog_SaveAnalyzeTask_Call{Call: _e.mock.On("SaveAnalyzeTask", ctx, task)} +} + +func (_c *DataCoordCatalog_SaveAnalyzeTask_Call) Run(run func(ctx context.Context, task *indexpb.AnalyzeTask)) *DataCoordCatalog_SaveAnalyzeTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*indexpb.AnalyzeTask)) + }) + return _c +} + +func (_c *DataCoordCatalog_SaveAnalyzeTask_Call) Return(_a0 error) *DataCoordCatalog_SaveAnalyzeTask_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *DataCoordCatalog_SaveAnalyzeTask_Call) RunAndReturn(run func(context.Context, *indexpb.AnalyzeTask) error) *DataCoordCatalog_SaveAnalyzeTask_Call { + _c.Call.Return(run) + return _c +} + // SaveChannelCheckpoint provides a mock function with given fields: ctx, vChannel, pos func (_m *DataCoordCatalog) SaveChannelCheckpoint(ctx context.Context, vChannel string, pos *msgpb.MsgPosition) error { ret := _m.Called(ctx, vChannel, pos) diff --git a/internal/mocks/mock_indexnode.go b/internal/mocks/mock_indexnode.go index bcd158dc7b..f81dba0c08 100644 --- a/internal/mocks/mock_indexnode.go +++ b/internal/mocks/mock_indexnode.go @@ -85,6 +85,61 @@ func (_c *MockIndexNode_CreateJob_Call) RunAndReturn(run func(context.Context, * return _c } +// CreateJobV2 provides a mock function with given fields: _a0, _a1 +func (_m *MockIndexNode) CreateJobV2(_a0 
context.Context, _a1 *indexpb.CreateJobV2Request) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobV2Request) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobV2Request) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.CreateJobV2Request) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNode_CreateJobV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateJobV2' +type MockIndexNode_CreateJobV2_Call struct { + *mock.Call +} + +// CreateJobV2 is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *indexpb.CreateJobV2Request +func (_e *MockIndexNode_Expecter) CreateJobV2(_a0 interface{}, _a1 interface{}) *MockIndexNode_CreateJobV2_Call { + return &MockIndexNode_CreateJobV2_Call{Call: _e.mock.On("CreateJobV2", _a0, _a1)} +} + +func (_c *MockIndexNode_CreateJobV2_Call) Run(run func(_a0 context.Context, _a1 *indexpb.CreateJobV2Request)) *MockIndexNode_CreateJobV2_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*indexpb.CreateJobV2Request)) + }) + return _c +} + +func (_c *MockIndexNode_CreateJobV2_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNode_CreateJobV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNode_CreateJobV2_Call) RunAndReturn(run func(context.Context, *indexpb.CreateJobV2Request) (*commonpb.Status, error)) *MockIndexNode_CreateJobV2_Call { + _c.Call.Return(run) + return _c +} + // DropJobs provides a mock function with given fields: _a0, _a1 func (_m *MockIndexNode) DropJobs(_a0 context.Context, _a1 *indexpb.DropJobsRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) @@ -140,6 +195,61 @@ func (_c *MockIndexNode_DropJobs_Call) RunAndReturn(run func(context.Context, *i return _c } +// DropJobsV2 provides a mock function with given fields: _a0, _a1 +func (_m *MockIndexNode) DropJobsV2(_a0 context.Context, _a1 *indexpb.DropJobsV2Request) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsV2Request) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsV2Request) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.DropJobsV2Request) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNode_DropJobsV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropJobsV2' +type MockIndexNode_DropJobsV2_Call struct { + *mock.Call +} + +// DropJobsV2 is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *indexpb.DropJobsV2Request +func (_e *MockIndexNode_Expecter) DropJobsV2(_a0 interface{}, _a1 interface{}) *MockIndexNode_DropJobsV2_Call { + return &MockIndexNode_DropJobsV2_Call{Call: _e.mock.On("DropJobsV2", _a0, _a1)} +} + +func (_c *MockIndexNode_DropJobsV2_Call) Run(run func(_a0 context.Context, _a1 *indexpb.DropJobsV2Request)) 
*MockIndexNode_DropJobsV2_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*indexpb.DropJobsV2Request)) + }) + return _c +} + +func (_c *MockIndexNode_DropJobsV2_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNode_DropJobsV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNode_DropJobsV2_Call) RunAndReturn(run func(context.Context, *indexpb.DropJobsV2Request) (*commonpb.Status, error)) *MockIndexNode_DropJobsV2_Call { + _c.Call.Return(run) + return _c +} + // GetAddress provides a mock function with given fields: func (_m *MockIndexNode) GetAddress() string { ret := _m.Called() @@ -497,6 +607,61 @@ func (_c *MockIndexNode_QueryJobs_Call) RunAndReturn(run func(context.Context, * return _c } +// QueryJobsV2 provides a mock function with given fields: _a0, _a1 +func (_m *MockIndexNode) QueryJobsV2(_a0 context.Context, _a1 *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) { + ret := _m.Called(_a0, _a1) + + var r0 *indexpb.QueryJobsV2Response + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsV2Request) *indexpb.QueryJobsV2Response); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.QueryJobsV2Response) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.QueryJobsV2Request) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNode_QueryJobsV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryJobsV2' +type MockIndexNode_QueryJobsV2_Call struct { + *mock.Call +} + +// QueryJobsV2 is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *indexpb.QueryJobsV2Request +func (_e *MockIndexNode_Expecter) QueryJobsV2(_a0 interface{}, _a1 interface{}) *MockIndexNode_QueryJobsV2_Call { + return &MockIndexNode_QueryJobsV2_Call{Call: _e.mock.On("QueryJobsV2", _a0, _a1)} +} + +func (_c *MockIndexNode_QueryJobsV2_Call) Run(run func(_a0 context.Context, _a1 *indexpb.QueryJobsV2Request)) *MockIndexNode_QueryJobsV2_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*indexpb.QueryJobsV2Request)) + }) + return _c +} + +func (_c *MockIndexNode_QueryJobsV2_Call) Return(_a0 *indexpb.QueryJobsV2Response, _a1 error) *MockIndexNode_QueryJobsV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNode_QueryJobsV2_Call) RunAndReturn(run func(context.Context, *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error)) *MockIndexNode_QueryJobsV2_Call { + _c.Call.Return(run) + return _c +} + // Register provides a mock function with given fields: func (_m *MockIndexNode) Register() error { ret := _m.Called() diff --git a/internal/mocks/mock_indexnode_client.go b/internal/mocks/mock_indexnode_client.go index 1e30de98ac..b21963a6b5 100644 --- a/internal/mocks/mock_indexnode_client.go +++ b/internal/mocks/mock_indexnode_client.go @@ -142,6 +142,76 @@ func (_c *MockIndexNodeClient_CreateJob_Call) RunAndReturn(run func(context.Cont return _c } +// CreateJobV2 provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) CreateJobV2(ctx context.Context, in *indexpb.CreateJobV2Request, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + 
_va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobV2Request, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobV2Request, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.CreateJobV2Request, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_CreateJobV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateJobV2' +type MockIndexNodeClient_CreateJobV2_Call struct { + *mock.Call +} + +// CreateJobV2 is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.CreateJobV2Request +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) CreateJobV2(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_CreateJobV2_Call { + return &MockIndexNodeClient_CreateJobV2_Call{Call: _e.mock.On("CreateJobV2", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_CreateJobV2_Call) Run(run func(ctx context.Context, in *indexpb.CreateJobV2Request, opts ...grpc.CallOption)) *MockIndexNodeClient_CreateJobV2_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.CreateJobV2Request), variadicArgs...) + }) + return _c +} + +func (_c *MockIndexNodeClient_CreateJobV2_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNodeClient_CreateJobV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_CreateJobV2_Call) RunAndReturn(run func(context.Context, *indexpb.CreateJobV2Request, ...grpc.CallOption) (*commonpb.Status, error)) *MockIndexNodeClient_CreateJobV2_Call { + _c.Call.Return(run) + return _c +} + // DropJobs provides a mock function with given fields: ctx, in, opts func (_m *MockIndexNodeClient) DropJobs(ctx context.Context, in *indexpb.DropJobsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { _va := make([]interface{}, len(opts)) @@ -212,6 +282,76 @@ func (_c *MockIndexNodeClient_DropJobs_Call) RunAndReturn(run func(context.Conte return _c } +// DropJobsV2 provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) DropJobsV2(ctx context.Context, in *indexpb.DropJobsV2Request, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsV2Request, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsV2Request, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) 
+ } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.DropJobsV2Request, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_DropJobsV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropJobsV2' +type MockIndexNodeClient_DropJobsV2_Call struct { + *mock.Call +} + +// DropJobsV2 is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.DropJobsV2Request +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) DropJobsV2(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_DropJobsV2_Call { + return &MockIndexNodeClient_DropJobsV2_Call{Call: _e.mock.On("DropJobsV2", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_DropJobsV2_Call) Run(run func(ctx context.Context, in *indexpb.DropJobsV2Request, opts ...grpc.CallOption)) *MockIndexNodeClient_DropJobsV2_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.DropJobsV2Request), variadicArgs...) + }) + return _c +} + +func (_c *MockIndexNodeClient_DropJobsV2_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNodeClient_DropJobsV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_DropJobsV2_Call) RunAndReturn(run func(context.Context, *indexpb.DropJobsV2Request, ...grpc.CallOption) (*commonpb.Status, error)) *MockIndexNodeClient_DropJobsV2_Call { + _c.Call.Return(run) + return _c +} + // GetComponentStates provides a mock function with given fields: ctx, in, opts func (_m *MockIndexNodeClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { _va := make([]interface{}, len(opts)) @@ -562,6 +702,76 @@ func (_c *MockIndexNodeClient_QueryJobs_Call) RunAndReturn(run func(context.Cont return _c } +// QueryJobsV2 provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) QueryJobsV2(ctx context.Context, in *indexpb.QueryJobsV2Request, opts ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *indexpb.QueryJobsV2Response + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsV2Request, ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsV2Request, ...grpc.CallOption) *indexpb.QueryJobsV2Response); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.QueryJobsV2Response) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.QueryJobsV2Request, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_QueryJobsV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryJobsV2' +type MockIndexNodeClient_QueryJobsV2_Call struct { + *mock.Call +} + +// QueryJobsV2 is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.QueryJobsV2Request +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) QueryJobsV2(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_QueryJobsV2_Call { + return &MockIndexNodeClient_QueryJobsV2_Call{Call: _e.mock.On("QueryJobsV2", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_QueryJobsV2_Call) Run(run func(ctx context.Context, in *indexpb.QueryJobsV2Request, opts ...grpc.CallOption)) *MockIndexNodeClient_QueryJobsV2_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.QueryJobsV2Request), variadicArgs...) + }) + return _c +} + +func (_c *MockIndexNodeClient_QueryJobsV2_Call) Return(_a0 *indexpb.QueryJobsV2Response, _a1 error) *MockIndexNodeClient_QueryJobsV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_QueryJobsV2_Call) RunAndReturn(run func(context.Context, *indexpb.QueryJobsV2Request, ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error)) *MockIndexNodeClient_QueryJobsV2_Call { + _c.Call.Return(run) + return _c +} + // ShowConfigurations provides a mock function with given fields: ctx, in, opts func (_m *MockIndexNodeClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { _va := make([]interface{}, len(opts)) diff --git a/internal/proto/clustering.proto b/internal/proto/clustering.proto new file mode 100644 index 0000000000..5d45889ffb --- /dev/null +++ b/internal/proto/clustering.proto @@ -0,0 +1,55 @@ +syntax = "proto3"; +package milvus.proto.clustering; + +option go_package = "github.com/milvus-io/milvus/internal/proto/clusteringpb"; +import "schema.proto"; + +// Synchronously modify StorageConfig in index_coord.proto/index_cgo_msg.proto file +message StorageConfig { + string address = 1; + string access_keyID = 2; + string secret_access_key = 3; + bool useSSL = 4; + string bucket_name = 5; + string root_path = 6; + bool useIAM = 7; + string IAMEndpoint = 8; + string storage_type = 9; + bool use_virtual_host = 10; + string region = 11; + string cloud_provider = 12; + int64 request_timeout_ms = 13; + string sslCACert = 14; +} + +message InsertFiles { + repeated string insert_files = 1; +} + +message AnalyzeInfo { + string clusterID = 1; + int64 buildID = 2; + int64 collectionID = 3; + int64 partitionID = 4; + int64 segmentID = 5; + int64 version = 6; + int64 dim = 7; + int64 num_clusters = 8; + int64 train_size = 9; + double min_cluster_ratio = 10; // min_cluster_size / avg_cluster_size < min_cluster_ratio, is skew + double max_cluster_ratio = 11; // max_cluster_size / avg_cluster_size > max_cluster_ratio, is skew + int64 max_cluster_size = 12; + map insert_files = 13; + map num_rows = 14; + schema.FieldSchema field_schema = 15; + StorageConfig storage_config = 16; +} + +message ClusteringCentroidsStats { + repeated schema.VectorField centroids = 1; +} + +message ClusteringCentroidIdMappingStats { + 
repeated uint32 centroid_id_mapping = 1; + repeated int64 num_in_centroid = 2; +} \ No newline at end of file diff --git a/internal/proto/index_coord.proto b/internal/proto/index_coord.proto index 0c0cea0361..b188dcbcf4 100644 --- a/internal/proto/index_coord.proto +++ b/internal/proto/index_coord.proto @@ -10,19 +10,19 @@ import "milvus.proto"; import "schema.proto"; service IndexCoord { - rpc GetComponentStates(milvus.GetComponentStatesRequest) returns (milvus.ComponentStates) {} - rpc GetStatisticsChannel(internal.GetStatisticsChannelRequest) returns(milvus.StringResponse){} - rpc CreateIndex(CreateIndexRequest) returns (common.Status){} - rpc AlterIndex(AlterIndexRequest) returns (common.Status){} - // Deprecated: use DescribeIndex instead - rpc GetIndexState(GetIndexStateRequest) returns (GetIndexStateResponse) {} - rpc GetSegmentIndexState(GetSegmentIndexStateRequest) returns (GetSegmentIndexStateResponse) {} - rpc GetIndexInfos(GetIndexInfoRequest) returns (GetIndexInfoResponse){} - rpc DropIndex(DropIndexRequest) returns (common.Status) {} - rpc DescribeIndex(DescribeIndexRequest) returns (DescribeIndexResponse) {} - rpc GetIndexStatistics(GetIndexStatisticsRequest) returns (GetIndexStatisticsResponse) {} - // Deprecated: use DescribeIndex instead - rpc GetIndexBuildProgress(GetIndexBuildProgressRequest) returns (GetIndexBuildProgressResponse) {} + rpc GetComponentStates(milvus.GetComponentStatesRequest) returns (milvus.ComponentStates) {} + rpc GetStatisticsChannel(internal.GetStatisticsChannelRequest) returns(milvus.StringResponse){} + rpc CreateIndex(CreateIndexRequest) returns (common.Status){} + rpc AlterIndex(AlterIndexRequest) returns (common.Status){} + // Deprecated: use DescribeIndex instead + rpc GetIndexState(GetIndexStateRequest) returns (GetIndexStateResponse) {} + rpc GetSegmentIndexState(GetSegmentIndexStateRequest) returns (GetSegmentIndexStateResponse) {} + rpc GetIndexInfos(GetIndexInfoRequest) returns (GetIndexInfoResponse){} + rpc DropIndex(DropIndexRequest) returns (common.Status) {} + rpc DescribeIndex(DescribeIndexRequest) returns (DescribeIndexResponse) {} + rpc GetIndexStatistics(GetIndexStatisticsRequest) returns (GetIndexStatisticsResponse) {} + // Deprecated: use DescribeIndex instead + rpc GetIndexBuildProgress(GetIndexBuildProgressRequest) returns (GetIndexBuildProgressResponse) {} rpc ShowConfigurations(internal.ShowConfigurationsRequest) returns (internal.ShowConfigurationsResponse) { @@ -60,6 +60,13 @@ service IndexNode { rpc GetMetrics(milvus.GetMetricsRequest) returns (milvus.GetMetricsResponse) { } + + rpc CreateJobV2(CreateJobV2Request) returns (common.Status) { + } + rpc QueryJobsV2(QueryJobsV2Request) returns (QueryJobsV2Response) { + } + rpc DropJobsV2(DropJobsV2Request) returns (common.Status) { + } } message IndexInfo { @@ -159,9 +166,9 @@ message CreateIndexRequest { } message AlterIndexRequest { - int64 collectionID = 1; - string index_name = 2; - repeated common.KeyValuePair params = 3; + int64 collectionID = 1; + string index_name = 2; + repeated common.KeyValuePair params = 3; } message GetIndexInfoRequest { @@ -226,7 +233,7 @@ message GetIndexBuildProgressResponse { int64 pending_index_rows = 4; } -// Synchronously modify StorageConfig in index_cgo_msg.proto file +// Synchronously modify StorageConfig in index_cgo_msg.proto/clustering.proto file message StorageConfig { string address = 1; string access_keyID = 2; @@ -246,11 +253,11 @@ message StorageConfig { // Synchronously modify OptionalFieldInfo in index_cgo_msg.proto file 
message OptionalFieldInfo { - int64 fieldID = 1; - string field_name = 2; - int32 field_type = 3; - repeated string data_paths = 4; - repeated int64 data_ids = 5; + int64 fieldID = 1; + string field_name = 2; + int32 field_type = 3; + repeated string data_paths = 4; + repeated int64 data_ids = 5; } message CreateJobRequest { @@ -347,3 +354,108 @@ message ListIndexesResponse { common.Status status = 1; repeated IndexInfo index_infos = 2; } + +message AnalyzeTask { + int64 collectionID = 1; + int64 partitionID = 2; + int64 fieldID = 3; + string field_name = 4; + schema.DataType field_type = 5; + int64 taskID = 6; + int64 version = 7; + repeated int64 segmentIDs = 8; + int64 nodeID = 9; + JobState state = 10; + string fail_reason = 11; + int64 dim = 12; + string centroids_file = 13; +} + +message SegmentStats { + int64 ID = 1; + int64 num_rows = 2; + repeated int64 logIDs = 3; +} + +message AnalyzeRequest { + string clusterID = 1; + int64 taskID = 2; + int64 collectionID = 3; + int64 partitionID = 4; + int64 fieldID = 5; + string fieldName = 6; + schema.DataType field_type = 7; + map segment_stats = 8; + int64 version = 9; + StorageConfig storage_config = 10; + int64 dim = 11; + double max_train_size_ratio = 12; + int64 num_clusters = 13; + schema.FieldSchema field = 14; + double min_cluster_size_ratio = 15; + double max_cluster_size_ratio = 16; + int64 max_cluster_size = 17; +} + +message AnalyzeResult { + int64 taskID = 1; + JobState state = 2; + string fail_reason = 3; + string centroids_file = 4; +} + +enum JobType { + JobTypeNone = 0; + JobTypeIndexJob = 1; + JobTypeAnalyzeJob = 2; +} + +message CreateJobV2Request { + string clusterID = 1; + int64 taskID = 2; + JobType job_type = 3; + oneof request { + AnalyzeRequest analyze_request = 4; + CreateJobRequest index_request = 5; + } + // JobDescriptor job = 3; +} + +message QueryJobsV2Request { + string clusterID = 1; + repeated int64 taskIDs = 2; + JobType job_type = 3; +} + +message IndexJobResults { + repeated IndexTaskInfo results = 1; +} + +message AnalyzeResults { + repeated AnalyzeResult results = 1; +} + +message QueryJobsV2Response { + common.Status status = 1; + string clusterID = 2; + oneof result { + IndexJobResults index_job_results = 3; + AnalyzeResults analyze_job_results = 4; + } +} + +message DropJobsV2Request { + string clusterID = 1; + repeated int64 taskIDs = 2; + JobType job_type = 3; +} + + +enum JobState { + JobStateNone = 0; + JobStateInit = 1; + JobStateInProgress = 2; + JobStateFinished = 3; + JobStateFailed = 4; + JobStateRetry = 5; +} \ No newline at end of file diff --git a/internal/util/analyzecgowrapper/analyze.go b/internal/util/analyzecgowrapper/analyze.go new file mode 100644 index 0000000000..1b5b631194 --- /dev/null +++ b/internal/util/analyzecgowrapper/analyze.go @@ -0,0 +1,116 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package analyzecgowrapper + +/* +#cgo pkg-config: milvus_clustering + +#include // free +#include "clustering/analyze_c.h" +*/ +import "C" + +import ( + "context" + "runtime" + "unsafe" + + "github.com/golang/protobuf/proto" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/clusteringpb" + "github.com/milvus-io/milvus/pkg/log" +) + +type CodecAnalyze interface { + Delete() error + GetResult(size int) (string, int64, []string, []int64, error) +} + +func Analyze(ctx context.Context, analyzeInfo *clusteringpb.AnalyzeInfo) (CodecAnalyze, error) { + analyzeInfoBlob, err := proto.Marshal(analyzeInfo) + if err != nil { + log.Ctx(ctx).Warn("marshal analyzeInfo failed", + zap.Int64("buildID", analyzeInfo.GetBuildID()), + zap.Error(err)) + return nil, err + } + var analyzePtr C.CAnalyze + status := C.Analyze(&analyzePtr, (*C.uint8_t)(unsafe.Pointer(&analyzeInfoBlob[0])), (C.uint64_t)(len(analyzeInfoBlob))) + if err := HandleCStatus(&status, "failed to analyze task"); err != nil { + return nil, err + } + + analyze := &CgoAnalyze{ + analyzePtr: analyzePtr, + close: false, + } + + runtime.SetFinalizer(analyze, func(ca *CgoAnalyze) { + if ca != nil && !ca.close { + log.Error("there is leakage in analyze object, please check.") + } + }) + + return analyze, nil +} + +type CgoAnalyze struct { + analyzePtr C.CAnalyze + close bool +} + +func (ca *CgoAnalyze) Delete() error { + if ca.close { + return nil + } + var status C.CStatus + if ca.analyzePtr != nil { + status = C.DeleteAnalyze(ca.analyzePtr) + } + ca.close = true + return HandleCStatus(&status, "failed to delete analyze") +} + +func (ca *CgoAnalyze) GetResult(size int) (string, int64, []string, []int64, error) { + cOffsetMappingFilesPath := make([]unsafe.Pointer, size) + cOffsetMappingFilesSize := make([]C.int64_t, size) + cCentroidsFilePath := C.CString("") + cCentroidsFileSize := C.int64_t(0) + defer C.free(unsafe.Pointer(cCentroidsFilePath)) + + status := C.GetAnalyzeResultMeta(ca.analyzePtr, + &cCentroidsFilePath, + &cCentroidsFileSize, + unsafe.Pointer(&cOffsetMappingFilesPath[0]), + &cOffsetMappingFilesSize[0], + ) + if err := HandleCStatus(&status, "failed to delete analyze"); err != nil { + return "", 0, nil, nil, err + } + offsetMappingFilesPath := make([]string, size) + offsetMappingFilesSize := make([]int64, size) + centroidsFilePath := C.GoString(cCentroidsFilePath) + centroidsFileSize := int64(cCentroidsFileSize) + + for i := 0; i < size; i++ { + offsetMappingFilesPath[i] = C.GoString((*C.char)(cOffsetMappingFilesPath[i])) + offsetMappingFilesSize[i] = int64(cOffsetMappingFilesSize[i]) + } + + return centroidsFilePath, centroidsFileSize, offsetMappingFilesPath, offsetMappingFilesSize, nil +} diff --git a/internal/util/analyzecgowrapper/helper.go b/internal/util/analyzecgowrapper/helper.go new file mode 100644 index 0000000000..5b2f0b8fcc --- /dev/null +++ b/internal/util/analyzecgowrapper/helper.go @@ -0,0 +1,55 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package analyzecgowrapper + +/* + +#cgo pkg-config: milvus_common + +#include // free +#include "common/type_c.h" +*/ +import "C" + +import ( + "fmt" + "unsafe" + + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +// HandleCStatus deal with the error returned from CGO +func HandleCStatus(status *C.CStatus, extraInfo string) error { + if status.error_code == 0 { + return nil + } + errorCode := int(status.error_code) + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + + logMsg := fmt.Sprintf("%s, C Runtime Exception: %s\n", extraInfo, errorMsg) + log.Warn(logMsg) + if errorCode == 2003 { + return merr.WrapErrSegcoreUnsupported(int32(errorCode), logMsg) + } + if errorCode == 2033 { + log.Info("fake finished the task") + return merr.ErrSegcorePretendFinished + } + return merr.WrapErrSegcore(int32(errorCode), logMsg) +} diff --git a/internal/util/mock/grpc_indexnode_client.go b/internal/util/mock/grpc_indexnode_client.go index d8bbbd57c7..ae180cd731 100644 --- a/internal/util/mock/grpc_indexnode_client.go +++ b/internal/util/mock/grpc_indexnode_client.go @@ -69,6 +69,18 @@ func (m *GrpcIndexNodeClient) ShowConfigurations(ctx context.Context, in *intern return &internalpb.ShowConfigurationsResponse{}, m.Err } +func (m *GrpcIndexNodeClient) CreateJobV2(ctx context.Context, in *indexpb.CreateJobV2Request, opt ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *GrpcIndexNodeClient) QueryJobsV2(ctx context.Context, in *indexpb.QueryJobsV2Request, opt ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) { + return &indexpb.QueryJobsV2Response{}, m.Err +} + +func (m *GrpcIndexNodeClient) DropJobsV2(ctx context.Context, in *indexpb.DropJobsV2Request, opt ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + func (m *GrpcIndexNodeClient) Close() error { return m.Err } diff --git a/pkg/common/common.go b/pkg/common/common.go index 723f231718..fc23db43ac 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -93,6 +93,11 @@ const ( // PartitionStatsPath storage path const for partition stats files PartitionStatsPath = `part_stats` + + // AnalyzeStatsPath storage path const for analyze. 
+ AnalyzeStatsPath = `analyze_stats` + OffsetMapping = `offset_mapping` + Centroids = "centroids" ) // Search, Index parameter keys diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 4b8c56d069..6c4baa873c 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -248,6 +248,9 @@ type commonConfig struct { BloomFilterType ParamItem `refreshable:"true"` MaxBloomFalsePositive ParamItem `refreshable:"true"` PanicWhenPluginFail ParamItem `refreshable:"false"` + + UsePartitionKeyAsClusteringKey ParamItem `refreshable:"true"` + UseVectorAsClusteringKey ParamItem `refreshable:"true"` } func (p *commonConfig) init(base *BaseTable) { @@ -761,6 +764,22 @@ like the old password verification when updating the credential`, Doc: "panic or not when plugin fail to init", } p.PanicWhenPluginFail.Init(base.mgr) + + p.UsePartitionKeyAsClusteringKey = ParamItem{ + Key: "common.usePartitionKeyAsClusteringKey", + Version: "2.4.2", + Doc: "if true, do clustering compaction and segment prune on partition key field", + DefaultValue: "false", + } + p.UsePartitionKeyAsClusteringKey.Init(base.mgr) + + p.UseVectorAsClusteringKey = ParamItem{ + Key: "common.useVectorAsClusteringKey", + Version: "2.4.2", + Doc: "if true, do clustering compaction and segment prune on vector field", + DefaultValue: "false", + } + p.UseVectorAsClusteringKey.Init(base.mgr) } type gpuConfig struct { @@ -2702,7 +2721,7 @@ user-task-polling: p.DefaultSegmentFilterRatio = ParamItem{ Key: "queryNode.defaultSegmentFilterRatio", Version: "2.4.0", - DefaultValue: "0.5", + DefaultValue: "2", Doc: "filter ratio used for pruning segments when searching", } p.DefaultSegmentFilterRatio.Init(base.mgr) @@ -2772,6 +2791,26 @@ type dataCoordConfig struct { ChannelCheckpointMaxLag ParamItem `refreshable:"true"` SyncSegmentsInterval ParamItem `refreshable:"false"` + // Clustering Compaction + ClusteringCompactionEnable ParamItem `refreshable:"true"` + ClusteringCompactionAutoEnable ParamItem `refreshable:"true"` + ClusteringCompactionTriggerInterval ParamItem `refreshable:"true"` + ClusteringCompactionStateCheckInterval ParamItem `refreshable:"true"` + ClusteringCompactionGCInterval ParamItem `refreshable:"true"` + ClusteringCompactionMinInterval ParamItem `refreshable:"true"` + ClusteringCompactionMaxInterval ParamItem `refreshable:"true"` + ClusteringCompactionNewDataSizeThreshold ParamItem `refreshable:"true"` + ClusteringCompactionDropTolerance ParamItem `refreshable:"true"` + ClusteringCompactionPreferSegmentSize ParamItem `refreshable:"true"` + ClusteringCompactionMaxSegmentSize ParamItem `refreshable:"true"` + ClusteringCompactionMaxTrainSizeRatio ParamItem `refreshable:"true"` + ClusteringCompactionTimeoutInSeconds ParamItem `refreshable:"true"` + ClusteringCompactionMaxCentroidsNum ParamItem `refreshable:"true"` + ClusteringCompactionMinCentroidsNum ParamItem `refreshable:"true"` + ClusteringCompactionMinClusterSizeRatio ParamItem `refreshable:"true"` + ClusteringCompactionMaxClusterSizeRatio ParamItem `refreshable:"true"` + ClusteringCompactionMaxClusterSize ParamItem `refreshable:"true"` + // LevelZero Segment EnableLevelZeroSegment ParamItem `refreshable:"false"` LevelZeroCompactionTriggerMinSize ParamItem `refreshable:"true"` @@ -3172,6 +3211,156 @@ During compaction, the size of segment # of rows is able to exceed segment max # } p.LevelZeroCompactionTriggerDeltalogMaxNum.Init(base.mgr) + p.ClusteringCompactionEnable = ParamItem{ + Key: 
"dataCoord.compaction.clustering.enable", + Version: "2.4.2", + DefaultValue: "false", + Doc: "Enable clustering compaction", + Export: true, + } + p.ClusteringCompactionEnable.Init(base.mgr) + + p.ClusteringCompactionAutoEnable = ParamItem{ + Key: "dataCoord.compaction.clustering.autoEnable", + Version: "2.4.2", + DefaultValue: "false", + Doc: "Enable auto clustering compaction", + Export: true, + } + p.ClusteringCompactionAutoEnable.Init(base.mgr) + + p.ClusteringCompactionTriggerInterval = ParamItem{ + Key: "dataCoord.compaction.clustering.triggerInterval", + Version: "2.4.2", + DefaultValue: "600", + } + p.ClusteringCompactionTriggerInterval.Init(base.mgr) + + p.ClusteringCompactionStateCheckInterval = ParamItem{ + Key: "dataCoord.compaction.clustering.stateCheckInterval", + Version: "2.4.2", + DefaultValue: "10", + } + p.ClusteringCompactionStateCheckInterval.Init(base.mgr) + + p.ClusteringCompactionGCInterval = ParamItem{ + Key: "dataCoord.compaction.clustering.gcInterval", + Version: "2.4.2", + DefaultValue: "600", + } + p.ClusteringCompactionGCInterval.Init(base.mgr) + + p.ClusteringCompactionMinInterval = ParamItem{ + Key: "dataCoord.compaction.clustering.minInterval", + Version: "2.4.2", + Doc: "The minimum interval between clustering compaction executions of one collection, to avoid redundant compaction", + DefaultValue: "3600", + } + p.ClusteringCompactionMinInterval.Init(base.mgr) + + p.ClusteringCompactionMaxInterval = ParamItem{ + Key: "dataCoord.compaction.clustering.maxInterval", + Version: "2.4.2", + Doc: "If a collection haven't been clustering compacted for longer than maxInterval, force compact", + DefaultValue: "86400", + } + p.ClusteringCompactionMaxInterval.Init(base.mgr) + + p.ClusteringCompactionNewDataSizeThreshold = ParamItem{ + Key: "dataCoord.compaction.clustering.newDataSizeThreshold", + Version: "2.4.2", + Doc: "If new data size is large than newDataSizeThreshold, execute clustering compaction", + DefaultValue: "512m", + } + p.ClusteringCompactionNewDataSizeThreshold.Init(base.mgr) + + p.ClusteringCompactionTimeoutInSeconds = ParamItem{ + Key: "dataCoord.compaction.clustering.timeout", + Version: "2.4.2", + DefaultValue: "3600", + } + p.ClusteringCompactionTimeoutInSeconds.Init(base.mgr) + + p.ClusteringCompactionDropTolerance = ParamItem{ + Key: "dataCoord.compaction.clustering.dropTolerance", + Version: "2.4.2", + Doc: "If clustering compaction job is finished for a long time, gc it", + DefaultValue: "259200", + } + p.ClusteringCompactionDropTolerance.Init(base.mgr) + + p.ClusteringCompactionPreferSegmentSize = ParamItem{ + Key: "dataCoord.compaction.clustering.preferSegmentSize", + Version: "2.4.2", + DefaultValue: "512m", + PanicIfEmpty: false, + Export: true, + } + p.ClusteringCompactionPreferSegmentSize.Init(base.mgr) + + p.ClusteringCompactionMaxSegmentSize = ParamItem{ + Key: "dataCoord.compaction.clustering.maxSegmentSize", + Version: "2.4.2", + DefaultValue: "1024m", + PanicIfEmpty: false, + Export: true, + } + p.ClusteringCompactionMaxSegmentSize.Init(base.mgr) + + p.ClusteringCompactionMaxTrainSizeRatio = ParamItem{ + Key: "dataCoord.compaction.clustering.maxTrainSizeRatio", + Version: "2.4.2", + DefaultValue: "0.8", + Doc: "max data size ratio in Kmeans train, if larger than it, will down sampling to meet this limit", + Export: true, + } + p.ClusteringCompactionMaxTrainSizeRatio.Init(base.mgr) + + p.ClusteringCompactionMaxCentroidsNum = ParamItem{ + Key: "dataCoord.compaction.clustering.maxCentroidsNum", + Version: "2.4.2", + DefaultValue: 
"10240", + Doc: "maximum centroids number in Kmeans train", + Export: true, + } + p.ClusteringCompactionMaxCentroidsNum.Init(base.mgr) + + p.ClusteringCompactionMinCentroidsNum = ParamItem{ + Key: "dataCoord.compaction.clustering.minCentroidsNum", + Version: "2.4.2", + DefaultValue: "16", + Doc: "minimum centroids number in Kmeans train", + Export: true, + } + p.ClusteringCompactionMinCentroidsNum.Init(base.mgr) + + p.ClusteringCompactionMinClusterSizeRatio = ParamItem{ + Key: "dataCoord.compaction.clustering.minClusterSizeRatio", + Version: "2.4.2", + DefaultValue: "0.01", + Doc: "minimum cluster size / avg size in Kmeans train", + Export: true, + } + p.ClusteringCompactionMinClusterSizeRatio.Init(base.mgr) + + p.ClusteringCompactionMaxClusterSizeRatio = ParamItem{ + Key: "dataCoord.compaction.clustering.maxClusterSizeRatio", + Version: "2.4.2", + DefaultValue: "10", + Doc: "maximum cluster size / avg size in Kmeans train", + Export: true, + } + p.ClusteringCompactionMaxClusterSizeRatio.Init(base.mgr) + + p.ClusteringCompactionMaxClusterSize = ParamItem{ + Key: "dataCoord.compaction.clustering.maxClusterSize", + Version: "2.4.2", + DefaultValue: "5g", + Doc: "maximum cluster size in Kmeans train", + Export: true, + } + p.ClusteringCompactionMaxClusterSize.Init(base.mgr) + p.EnableGarbageCollection = ParamItem{ Key: "dataCoord.enableGarbageCollection", Version: "2.0.0", @@ -3475,6 +3664,10 @@ type dataNodeConfig struct { // slot SlotCap ParamItem `refreshable:"true"` + + // clustering compaction + ClusteringCompactionMemoryBufferRatio ParamItem `refreshable:"true"` + ClusteringCompactionWorkerPoolSize ParamItem `refreshable:"true"` } func (p *dataNodeConfig) init(base *BaseTable) { @@ -3789,6 +3982,26 @@ if this parameter <= 0, will set it as 10`, Export: true, } p.SlotCap.Init(base.mgr) + + p.ClusteringCompactionMemoryBufferRatio = ParamItem{ + Key: "datanode.clusteringCompaction.memoryBufferRatio", + Version: "2.4.2", + Doc: "The ratio of memory buffer of clustering compaction. 
Data larger than threshold will be spilled to storage.", + DefaultValue: "0.1", + PanicIfEmpty: false, + Export: true, + } + p.ClusteringCompactionMemoryBufferRatio.Init(base.mgr) + + p.ClusteringCompactionWorkerPoolSize = ParamItem{ + Key: "datanode.clusteringCompaction.cpu", + Version: "2.4.2", + Doc: "worker pool size for one clustering compaction job.", + DefaultValue: "1", + PanicIfEmpty: false, + Export: true, + } + p.ClusteringCompactionWorkerPoolSize.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index 34e6d409c8..021445ce79 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -429,6 +429,23 @@ func TestComponentParam(t *testing.T) { params.Save("datacoord.gracefulStopTimeout", "100") assert.Equal(t, 100*time.Second, Params.GracefulStopTimeout.GetAsDuration(time.Second)) + + params.Save("dataCoord.compaction.clustering.enable", "true") + assert.Equal(t, true, Params.ClusteringCompactionEnable.GetAsBool()) + params.Save("dataCoord.compaction.clustering.newDataSizeThreshold", "10") + assert.Equal(t, int64(10), Params.ClusteringCompactionNewDataSizeThreshold.GetAsSize()) + params.Save("dataCoord.compaction.clustering.newDataSizeThreshold", "10k") + assert.Equal(t, int64(10*1024), Params.ClusteringCompactionNewDataSizeThreshold.GetAsSize()) + params.Save("dataCoord.compaction.clustering.newDataSizeThreshold", "10m") + assert.Equal(t, int64(10*1024*1024), Params.ClusteringCompactionNewDataSizeThreshold.GetAsSize()) + params.Save("dataCoord.compaction.clustering.newDataSizeThreshold", "10g") + assert.Equal(t, int64(10*1024*1024*1024), Params.ClusteringCompactionNewDataSizeThreshold.GetAsSize()) + params.Save("dataCoord.compaction.clustering.dropTolerance", "86400") + assert.Equal(t, int64(86400), Params.ClusteringCompactionDropTolerance.GetAsInt64()) + params.Save("dataCoord.compaction.clustering.maxSegmentSize", "100m") + assert.Equal(t, int64(100*1024*1024), Params.ClusteringCompactionMaxSegmentSize.GetAsSize()) + params.Save("dataCoord.compaction.clustering.preferSegmentSize", "10m") + assert.Equal(t, int64(10*1024*1024), Params.ClusteringCompactionPreferSegmentSize.GetAsSize()) }) t.Run("test dataNodeConfig", func(t *testing.T) { @@ -482,6 +499,9 @@ func TestComponentParam(t *testing.T) { params.Save("datanode.gracefulStopTimeout", "100") assert.Equal(t, 100*time.Second, Params.GracefulStopTimeout.GetAsDuration(time.Second)) assert.Equal(t, 2, Params.SlotCap.GetAsInt()) + // clustering compaction + params.Save("datanode.clusteringCompaction.memoryBufferRatio", "0.1") + assert.Equal(t, 0.1, Params.ClusteringCompactionMemoryBufferRatio.GetAsFloat()) }) t.Run("test indexNodeConfig", func(t *testing.T) { @@ -500,6 +520,14 @@ func TestComponentParam(t *testing.T) { assert.Equal(t, "by-dev-dml1", Params.RootCoordDml.GetValue()) }) + + t.Run("clustering compaction config", func(t *testing.T) { + Params := ¶ms.CommonCfg + params.Save("common.usePartitionKeyAsClusteringKey", "true") + assert.Equal(t, true, Params.UsePartitionKeyAsClusteringKey.GetAsBool()) + params.Save("common.useVectorAsClusteringKey", "true") + assert.Equal(t, true, Params.UseVectorAsClusteringKey.GetAsBool()) + }) } func TestForbiddenItem(t *testing.T) { diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index dfa35f2109..ae56f64007 100644 --- a/pkg/util/typeutil/schema.go +++ 
b/pkg/util/typeutil/schema.go @@ -378,6 +378,20 @@ func IsDenseFloatVectorType(dataType schemapb.DataType) bool { } } +// VectorTypeSize returns the size in bytes of a single dimension for the given vector data type +func VectorTypeSize(dataType schemapb.DataType) float64 { + switch dataType { + case schemapb.DataType_FloatVector, schemapb.DataType_SparseFloatVector: + return 4.0 + case schemapb.DataType_BinaryVector: + return 0.125 + case schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector: + return 2.0 + default: + return 0.0 + } +} + func IsSparseFloatVectorType(dataType schemapb.DataType) bool { return dataType == schemapb.DataType_SparseFloatVector } @@ -1085,6 +1099,16 @@ func IsFieldDataTypeSupportMaterializedView(fieldSchema *schemapb.FieldSchema) b return fieldSchema.DataType == schemapb.DataType_VarChar || fieldSchema.DataType == schemapb.DataType_String } +// HasClusterKey checks whether a collection schema contains a clustering key field +func HasClusterKey(schema *schemapb.CollectionSchema) bool { + for _, fieldSchema := range schema.Fields { + if fieldSchema.IsClusteringKey { + return true + } + } + return false +} + // GetPrimaryFieldData get primary field data from all field data inserted from sdk func GetPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schemapb.FieldSchema) (*schemapb.FieldData, error) { primaryFieldID := primaryFieldSchema.FieldID diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index 6e6a6ec698..99afb89152 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -1047,6 +1047,40 @@ func TestGetPrimaryFieldSchema(t *testing.T) { assert.True(t, hasPartitionKey2) } +func TestGetClusterKeyFieldSchema(t *testing.T) { + int64Field := &schemapb.FieldSchema{ + FieldID: 1, + Name: "int64Field", + DataType: schemapb.DataType_Int64, + } + + clusterKeyfloatField := &schemapb.FieldSchema{ + FieldID: 2, + Name: "floatField", + DataType: schemapb.DataType_Float, + IsClusteringKey: true, + } + + unClusterKeyfloatField := &schemapb.FieldSchema{ + FieldID: 2, + Name: "floatField", + DataType: schemapb.DataType_Float, + IsClusteringKey: false, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{int64Field, clusterKeyfloatField}, + } + schema2 := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{int64Field, unClusterKeyfloatField}, + } + + hasClusterKey1 := HasClusterKey(schema) + assert.True(t, hasClusterKey1) + hasClusterKey2 := HasClusterKey(schema2) + assert.False(t, hasClusterKey2) +} + func TestGetPK(t *testing.T) { type args struct { data *schemapb.IDs diff --git a/scripts/generate_proto.sh b/scripts/generate_proto.sh index 286570b842..5b92bef12e 100755 --- a/scripts/generate_proto.sh +++ b/scripts/generate_proto.sh @@ -50,6 +50,7 @@ mkdir -p internalpb mkdir -p rootcoordpb mkdir -p segcorepb +mkdir -p clusteringpb mkdir -p proxypb mkdir -p indexpb @@ -72,6 +73,7 @@ ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./datapb data_coord.pr ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./querypb query_coord.proto|| { echo 'generate query_coord.proto failed'; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./planpb plan.proto|| { echo 'generate plan.proto failed'; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./segcorepb segcore.proto|| { echo 'generate segcore.proto failed'; exit 1; } +${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./clusteringpb clustering.proto|| { echo 'generate clustering.proto failed'; exit 1; } ${protoc_opt} 
--proto_path=$ROOT_DIR/cmd/tools/migration/legacy/ \ --go_out=plugins=grpc,paths=source_relative:../../cmd/tools/migration/legacy/legacypb legacy.proto || { echo 'generate legacy.proto failed'; exit 1; } @@ -79,10 +81,9 @@ ${protoc_opt} --proto_path=$ROOT_DIR/cmd/tools/migration/legacy/ \ ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb schema.proto|| { echo 'generate schema.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb common.proto|| { echo 'generate common.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb segcore.proto|| { echo 'generate segcore.proto failed'; exit 1; } +${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb clustering.proto|| { echo 'generate clustering.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb index_cgo_msg.proto|| { echo 'generate index_cgo_msg.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb cgo_msg.proto|| { echo 'generate cgo_msg.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb plan.proto|| { echo 'generate plan.proto failed'; exit 1; } popd - - diff --git a/tests/go_client/go.mod b/tests/go_client/go.mod index 665fdcc11f..6241b69c92 100644 --- a/tests/go_client/go.mod +++ b/tests/go_client/go.mod @@ -1,6 +1,8 @@ module github.com/milvus-io/milvus/tests/go_client -go 1.20 +go 1.21 + +toolchain go1.21.10 require ( github.com/milvus-io/milvus/client/v2 v2.0.0-20240521081339-017fd7bc25de
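For reference, the sketch below is not part of the patch. It is a minimal, hypothetical example of how the additions above fit together: the new clustering-compaction ParamItems are read through the usual paramtable accessors, and typeutil.HasClusterKey reports whether a schema declares a clustering key. It assumes the standard Milvus import paths, the paramtable.Init/Get helpers, and the DataCoordCfg/DataNodeCfg accessor names; the main function and its output are purely illustrative.

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

func main() {
	// Initialize the global param table so defaults ("5g", "16", "0.1", ...) are loaded.
	paramtable.Init()
	params := paramtable.Get()

	// Size-style values such as "5g" are parsed into bytes by GetAsSize,
	// matching the assertions added in component_param_test.go.
	maxClusterSize := params.DataCoordCfg.ClusteringCompactionMaxClusterSize.GetAsSize()
	minCentroids := params.DataCoordCfg.ClusteringCompactionMinCentroidsNum.GetAsInt64()
	bufferRatio := params.DataNodeCfg.ClusteringCompactionMemoryBufferRatio.GetAsFloat()
	fmt.Println(maxClusterSize, minCentroids, bufferRatio)

	// HasClusterKey returns true as soon as any field is marked IsClusteringKey.
	schema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "category", DataType: schemapb.DataType_Int64, IsClusteringKey: true},
		},
	}
	fmt.Println(typeutil.HasClusterKey(schema)) // true
}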