enhance: Support analyze data (#33651)

issue: #30633

Signed-off-by: Cai Zhang <cai.zhang@zilliz.com>
Co-authored-by: chasingegg <chao.gao@zilliz.com>
cai.zhang 2024-06-06 17:37:51 +08:00 committed by GitHub
parent cfea3f43cf
commit 27cc9f2630
82 changed files with 7942 additions and 2930 deletions

View File

@ -465,6 +465,7 @@ generate-mockery-datacoord: getdeps
$(INSTALL_PATH)/mockery --name=ChannelManager --dir=internal/datacoord --filename=mock_channelmanager.go --output=internal/datacoord --structname=MockChannelManager --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=SubCluster --dir=internal/datacoord --filename=mock_subcluster.go --output=internal/datacoord --structname=MockSubCluster --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=Broker --dir=internal/datacoord/broker --filename=mock_coordinator_broker.go --output=internal/datacoord/broker --structname=MockBroker --with-expecter --inpackage
$(INSTALL_PATH)/mockery --name=WorkerManager --dir=internal/datacoord --filename=mock_worker_manager.go --output=internal/datacoord --structname=MockWorkerManager --with-expecter --inpackage
generate-mockery-datanode: getdeps
$(INSTALL_PATH)/mockery --name=Allocator --dir=$(PWD)/internal/datanode/allocator --output=$(PWD)/internal/datanode/allocator --filename=mock_allocator.go --with-expecter --structname=MockAllocator --outpkg=allocator --inpackage

View File

@ -448,6 +448,30 @@ dataCoord:
rpcTimeout: 10
maxParallelTaskNum: 10
workerMaxParallelTaskNum: 2
clustering:
enable: true
autoEnable: false
triggerInterval: 600
stateCheckInterval: 10
gcInterval: 600
minInterval: 3600
maxInterval: 259200
newDataRatioThreshold: 0.2
newDataSizeThreshold: 512m
timeout: 7200
dropTolerance: 86400
# Clustering compaction tries its best to distribute data into segments whose sizes fall within [preferSegmentSize, maxSegmentSize].
# Data is clustered toward preferSegmentSize; if a cluster grows larger than maxSegmentSize, it is split into multiple segments.
# The buffer between preferSegmentSize and maxSegmentSize is reserved for new data in the same cluster (range), so that a global redistribution is not triggered too often (a worked illustration follows this config block).
preferSegmentSize: 512m
maxSegmentSize: 1024m
maxTrainSizeRatio: 0.8 # maximum ratio of the data used for analyze training; if the data is larger than this limit, it is down-sampled to meet it
maxCentroidsNum: 10240
minCentroidsNum: 16
minClusterSizeRatio: 0.01
maxClusterSizeRatio: 10
maxClusterSize: 5g
levelzero:
forceTrigger:
minSize: 8388608 # The minimum size in bytes to force-trigger a LevelZero compaction, defaults to 8MB
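A hypothetical illustration of the clustering segment-sizing settings above (editorial, not part of the shipped config): with preferSegmentSize: 512m and maxSegmentSize: 1024m, clustering compaction aims for segments of roughly 512 MB. A cluster (range) that grows to about 1.5 GB exceeds maxSegmentSize and is split into several segments around preferSegmentSize, while a cluster of about 800 MB stays in a single segment, the gap between the two limits absorbing new data without triggering a global redistribution.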
@ -535,6 +559,9 @@ dataNode:
slot:
slotCap: 2 # The maximum number of tasks(e.g. compaction, importing) allowed to run concurrently on a datanode.
clusteringCompaction:
memoryBufferRatio: 0.1 # The ratio of memory reserved as the clustering compaction buffer. Data exceeding the threshold is spilled to storage.
# Configures the system log output.
log:
level: info # Only supports debug, info, warn, error, panic, or fatal. Default 'info'.
@ -615,6 +642,8 @@ common:
traceLogMode: 0 # trace request info
bloomFilterSize: 100000 # bloom filter initial size
maxBloomFalsePositive: 0.001 # max false positive rate for bloom filter
usePartitionKeyAsClusteringKey: false
useVectorAsClusteringKey: false
# QuotaConfig, configurations of Milvus quota and limits.
# By default, we enable:

View File

@ -293,6 +293,12 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/indexbuilder/
FILES_MATCHING PATTERN "*_c.h"
)
# Install clustering
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/clustering/
DESTINATION include/clustering
FILES_MATCHING PATTERN "*_c.h"
)
# Install common
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/common/
DESTINATION include/common

View File

@ -32,5 +32,6 @@ add_subdirectory( index )
add_subdirectory( query )
add_subdirectory( segcore )
add_subdirectory( indexbuilder )
add_subdirectory( clustering )
add_subdirectory( exec )
add_subdirectory( bitset )

View File

@ -0,0 +1,24 @@
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License
set(CLUSTERING_FILES
analyze_c.cpp
KmeansClustering.cpp
)
milvus_add_pkg_config("milvus_clustering")
add_library(milvus_clustering SHARED ${CLUSTERING_FILES})
# link order matters
target_link_libraries(milvus_clustering milvus_index)
install(TARGETS milvus_clustering DESTINATION "${CMAKE_INSTALL_LIBDIR}")

View File

@ -0,0 +1,532 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "index/VectorDiskIndex.h"
#include "common/Tracer.h"
#include "common/Utils.h"
#include "config/ConfigKnowhere.h"
#include "index/Meta.h"
#include "index/Utils.h"
#include "knowhere/cluster/cluster_factory.h"
#include "knowhere/comp/time_recorder.h"
#include "clustering/KmeansClustering.h"
#include "segcore/SegcoreConfig.h"
#include "storage/LocalChunkManagerSingleton.h"
#include "storage/Util.h"
#include "common/Consts.h"
#include "common/RangeSearchHelper.h"
#include "clustering/types.h"
#include "clustering/file_utils.h"
#include <random>
namespace milvus::clustering {
KmeansClustering::KmeansClustering(
const storage::FileManagerContext& file_manager_context) {
file_manager_ =
std::make_unique<storage::MemFileManagerImpl>(file_manager_context);
AssertInfo(file_manager_ != nullptr, "create file manager failed!");
int64_t collection_id = file_manager_context.fieldDataMeta.collection_id;
int64_t partition_id = file_manager_context.fieldDataMeta.partition_id;
msg_header_ = fmt::format(
"collection: {}, partition: {} ", collection_id, partition_id);
}
template <typename T>
void
KmeansClustering::FetchDataFiles(uint8_t* buf,
const int64_t expected_train_size,
const int64_t expected_remote_file_size,
const std::vector<std::string>& files,
const int64_t dim,
int64_t& offset) {
// CacheRawDataToMemory is mostly used to pull the files of a single segment,
// so we can assume memory is always sufficient in those cases.
// In clustering, however, we first pre-allocate a large buffer (its size is controlled by config) for later knowhere usage when sampling train data,
// and the pulling stage adds temporary memory on top of it (pulled file + memcpy into the pre-allocated buffer), so we limit the batch size here.
auto batch = size_t(DEFAULT_FIELD_MAX_MEMORY_LIMIT / FILE_SLICE_SIZE);
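// Illustration only (the actual constants are defined elsewhere and their
// values may differ): if DEFAULT_FIELD_MAX_MEMORY_LIMIT were 64 MB and
// FILE_SLICE_SIZE were 16 MB, batch would be 4, i.e. at most 4 binlog files
// are pulled per CacheRawDataToMemory call, bounding the temporary memory
// used on top of the pre-allocated train buffer.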
int64_t fetched_file_size = 0;
for (size_t i = 0; i < files.size(); i += batch) {
size_t start = i;
size_t end = std::min(files.size(), i + batch);
std::vector<std::string> group_files(files.begin() + start,
files.begin() + end);
auto field_datas = file_manager_->CacheRawDataToMemory(group_files);
for (auto& data : field_datas) {
size_t size = std::min(expected_train_size - offset, data->Size());
if (size <= 0) {
break;
}
fetched_file_size += size;
std::memcpy(buf + offset, data->Data(), size);
offset += size;
data.reset();
}
}
AssertInfo(fetched_file_size == expected_remote_file_size,
"file size inconsistent, expected: {}, actual: {}",
expected_remote_file_size,
fetched_file_size);
}
template <typename T>
void
KmeansClustering::SampleTrainData(
const std::vector<int64_t>& segment_ids,
const std::map<int64_t, std::vector<std::string>>& segment_file_paths,
const std::map<int64_t, int64_t>& segment_num_rows,
const int64_t expected_train_size,
const int64_t dim,
const bool random_sample,
uint8_t* buf) {
int64_t offset = 0;
std::vector<std::string> files;
if (random_sample) {
for (auto& [segment_id, segment_files] : segment_file_paths) {
for (auto& segment_file : segment_files) {
files.emplace_back(segment_file);
}
}
// shuffle files
std::shuffle(files.begin(), files.end(), std::mt19937());
FetchDataFiles<T>(
buf, expected_train_size, expected_train_size, files, dim, offset);
return;
}
// pick all segment_ids without shuffling,
// and pull each segment's data only once so the id mapping can be reused in the assign stage
for (auto i = 0; i < segment_ids.size(); i++) {
if (offset == expected_train_size) {
break;
}
int64_t cur_segment_id = segment_ids[i];
files = segment_file_paths.at(cur_segment_id);
std::sort(files.begin(),
files.end(),
[](const std::string& a, const std::string& b) {
return std::stol(a.substr(a.find_last_of("/") + 1)) <
std::stol(b.substr(b.find_last_of("/") + 1));
});
FetchDataFiles<T>(buf,
expected_train_size,
segment_num_rows.at(cur_segment_id) * dim * sizeof(T),
files,
dim,
offset);
}
}
template <typename T>
milvus::proto::clustering::ClusteringCentroidsStats
KmeansClustering::CentroidsToPB(const T* centroids,
const int64_t num_clusters,
const int64_t dim) {
milvus::proto::clustering::ClusteringCentroidsStats stats;
for (auto i = 0; i < num_clusters; i++) {
milvus::proto::schema::VectorField* vector_field =
stats.add_centroids();
vector_field->set_dim(dim);
milvus::proto::schema::FloatArray* float_array =
vector_field->mutable_float_vector();
for (auto j = 0; j < dim; j++) {
float_array->add_data(float(centroids[i * dim + j]));
}
}
return stats;
}
std::vector<milvus::proto::clustering::ClusteringCentroidIdMappingStats>
KmeansClustering::CentroidIdMappingToPB(
const uint32_t* centroid_id_mapping,
const std::vector<int64_t>& segment_ids,
const int64_t trained_segments_num,
const std::map<int64_t, int64_t>& num_row_map,
const int64_t num_clusters) {
auto compute_num_in_centroid = [&](const uint32_t* centroid_id_mapping,
uint64_t start,
uint64_t end) -> std::vector<int64_t> {
std::vector<int64_t> num_vectors(num_clusters, 0);
for (uint64_t i = start; i < end; ++i) {
num_vectors[centroid_id_mapping[i]]++;
}
return num_vectors;
};
std::vector<milvus::proto::clustering::ClusteringCentroidIdMappingStats>
stats_arr;
stats_arr.reserve(trained_segments_num);
int64_t cur_offset = 0;
for (auto i = 0; i < trained_segments_num; i++) {
milvus::proto::clustering::ClusteringCentroidIdMappingStats stats;
auto num_offset = num_row_map.at(segment_ids[i]);
for (auto j = 0; j < num_offset; j++) {
stats.add_centroid_id_mapping(centroid_id_mapping[cur_offset + j]);
}
auto num_vectors = compute_num_in_centroid(
centroid_id_mapping, cur_offset, cur_offset + num_offset);
for (uint64_t j = 0; j < num_clusters; j++) {
stats.add_num_in_centroid(num_vectors[j]);
}
cur_offset += num_offset;
stats_arr.emplace_back(stats);
}
return stats_arr;
}
template <typename T>
bool
KmeansClustering::IsDataSkew(
const milvus::proto::clustering::AnalyzeInfo& config,
const int64_t dim,
std::vector<int64_t>& num_in_each_centroid) {
auto min_cluster_ratio = config.min_cluster_ratio();
auto max_cluster_ratio = config.max_cluster_ratio();
auto max_cluster_size = config.max_cluster_size();
std::sort(num_in_each_centroid.begin(), num_in_each_centroid.end());
size_t avg_size =
std::accumulate(
num_in_each_centroid.begin(), num_in_each_centroid.end(), 0) /
(num_in_each_centroid.size());
if (num_in_each_centroid.front() <= min_cluster_ratio * avg_size) {
LOG_INFO(msg_header_ + "minimum cluster too small: {}, avg: {}",
num_in_each_centroid.front(),
avg_size);
return true;
}
if (num_in_each_centroid.back() >= max_cluster_ratio * avg_size) {
LOG_INFO(msg_header_ + "maximum cluster too large: {}, avg: {}",
num_in_each_centroid.back(),
avg_size);
return true;
}
if (num_in_each_centroid.back() * dim * sizeof(T) >= max_cluster_size) {
LOG_INFO(msg_header_ + "maximum cluster size too large: {}B",
num_in_each_centroid.back() * dim * sizeof(T));
return true;
}
return false;
}
template <typename T>
void
KmeansClustering::StreamingAssignandUpload(
knowhere::Cluster<knowhere::ClusterNode>& cluster_node,
const milvus::proto::clustering::AnalyzeInfo& config,
const milvus::proto::clustering::ClusteringCentroidsStats& centroid_stats,
const std::vector<
milvus::proto::clustering::ClusteringCentroidIdMappingStats>&
id_mapping_stats,
const std::vector<int64_t>& segment_ids,
const std::map<int64_t, std::vector<std::string>>& insert_files,
const std::map<int64_t, int64_t>& num_rows,
const int64_t dim,
const int64_t trained_segments_num,
const int64_t num_clusters) {
auto byte_size = centroid_stats.ByteSizeLong();
std::unique_ptr<uint8_t[]> data = std::make_unique<uint8_t[]>(byte_size);
centroid_stats.SerializeToArray(data.get(), byte_size);
std::unordered_map<std::string, int64_t> remote_paths_to_size;
LOG_INFO(msg_header_ + "start upload cluster centroids file");
AddClusteringResultFiles(
file_manager_->GetChunkManager().get(),
data.get(),
byte_size,
GetRemoteCentroidsObjectPrefix() + "/" + std::string(CENTROIDS_NAME),
remote_paths_to_size);
cluster_result_.centroid_path =
GetRemoteCentroidsObjectPrefix() + "/" + std::string(CENTROIDS_NAME);
cluster_result_.centroid_file_size =
remote_paths_to_size.at(cluster_result_.centroid_path);
remote_paths_to_size.clear();
LOG_INFO(msg_header_ + "upload cluster centroids file done");
LOG_INFO(msg_header_ + "start upload cluster id mapping file");
std::vector<int64_t> num_vectors_each_centroid(num_clusters, 0);
auto serializeIdMappingAndUpload = [&](const int64_t segment_id,
const milvus::proto::clustering::
ClusteringCentroidIdMappingStats&
id_mapping_pb) {
auto byte_size = id_mapping_pb.ByteSizeLong();
std::unique_ptr<uint8_t[]> data =
std::make_unique<uint8_t[]>(byte_size);
id_mapping_pb.SerializeToArray(data.get(), byte_size);
AddClusteringResultFiles(
file_manager_->GetChunkManager().get(),
data.get(),
byte_size,
GetRemoteCentroidIdMappingObjectPrefix(segment_id) + "/" +
std::string(OFFSET_MAPPING_NAME),
remote_paths_to_size);
LOG_INFO(
msg_header_ +
"upload segment {} cluster id mapping file with size {} B done",
segment_id,
byte_size);
};
for (size_t i = 0; i < segment_ids.size(); i++) {
int64_t segment_id = segment_ids[i];
// id mapping has been computed, just upload to remote
if (i < trained_segments_num) {
serializeIdMappingAndUpload(segment_id, id_mapping_stats[i]);
for (int64_t j = 0; j < num_clusters; ++j) {
num_vectors_each_centroid[j] +=
id_mapping_stats[i].num_in_centroid(j);
}
} else { // streaming download raw data, assign id mapping, then upload
int64_t num_row = num_rows.at(segment_id);
std::unique_ptr<T[]> buf = std::make_unique<T[]>(num_row * dim);
int64_t offset = 0;
FetchDataFiles<T>(reinterpret_cast<uint8_t*>(buf.get()),
INT64_MAX,
num_row * dim * sizeof(T),
insert_files.at(segment_id),
dim,
offset);
auto dataset = GenDataset(num_row, dim, buf.release());
dataset->SetIsOwner(true);
auto res = cluster_node.Assign(*dataset);
if (!res.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
fmt::format("failed to kmeans assign: {}: {}",
KnowhereStatusString(res.error()),
res.what()));
}
res.value()->SetIsOwner(true);
auto id_mapping =
reinterpret_cast<const uint32_t*>(res.value()->GetTensor());
auto id_mapping_pb = CentroidIdMappingToPB(
id_mapping, {segment_id}, 1, num_rows, num_clusters)[0];
for (int64_t j = 0; j < num_clusters; ++j) {
num_vectors_each_centroid[j] +=
id_mapping_pb.num_in_centroid(j);
}
serializeIdMappingAndUpload(segment_id, id_mapping_pb);
}
}
if (IsDataSkew<T>(config, dim, num_vectors_each_centroid)) {
LOG_INFO(msg_header_ + "data skew! skip clustering");
// remove uploaded files
RemoveClusteringResultFiles(file_manager_->GetChunkManager().get(),
remote_paths_to_size);
// skip clustering, nothing takes effect
throw SegcoreError(ErrorCode::ClusterSkip,
"data skew! skip clustering");
}
LOG_INFO(msg_header_ + "upload cluster id mapping file done");
cluster_result_.id_mappings = std::move(remote_paths_to_size);
is_runned_ = true;
}
template <typename T>
void
KmeansClustering::Run(const milvus::proto::clustering::AnalyzeInfo& config) {
std::map<int64_t, std::vector<std::string>> insert_files;
for (const auto& pair : config.insert_files()) {
std::vector<std::string> segment_files(
pair.second.insert_files().begin(),
pair.second.insert_files().end());
insert_files[pair.first] = segment_files;
}
std::map<int64_t, int64_t> num_rows(config.num_rows().begin(),
config.num_rows().end());
auto num_clusters = config.num_clusters();
AssertInfo(num_clusters > 0, "num clusters must larger than 0");
auto train_size = config.train_size();
AssertInfo(train_size > 0, "train size must larger than 0");
auto dim = config.dim();
auto min_cluster_ratio = config.min_cluster_ratio();
AssertInfo(min_cluster_ratio > 0 && min_cluster_ratio < 1,
"min cluster ratio must larger than 0, less than 1");
auto max_cluster_ratio = config.max_cluster_ratio();
AssertInfo(max_cluster_ratio > 1, "max cluster ratio must larger than 1");
auto max_cluster_size = config.max_cluster_size();
AssertInfo(max_cluster_size > 0, "max cluster size must larger than 0");
auto cluster_node_obj =
knowhere::ClusterFactory::Instance().Create<T>(KMEANS_CLUSTER);
knowhere::Cluster<knowhere::ClusterNode> cluster_node;
if (cluster_node_obj.has_value()) {
cluster_node = std::move(cluster_node_obj.value());
} else {
auto err = cluster_node_obj.error();
if (err == knowhere::Status::invalid_cluster_error) {
throw SegcoreError(ErrorCode::ClusterSkip, cluster_node_obj.what());
}
throw SegcoreError(ErrorCode::KnowhereError, cluster_node_obj.what());
}
size_t data_num = 0;
std::vector<int64_t> segment_ids;
for (auto& [segment_id, num_row_each_segment] : num_rows) {
data_num += num_row_each_segment;
segment_ids.emplace_back(segment_id);
AssertInfo(insert_files.find(segment_id) != insert_files.end(),
"segment id {} not exist in insert files",
segment_id);
}
size_t trained_segments_num = 0;
size_t data_size = data_num * dim * sizeof(T);
size_t train_num = train_size / sizeof(T) / dim;
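// train_size is a byte budget, converted here into a number of vectors.
// Worked example (matching the 6 MB case in the unit test added in this
// commit): train_size = 6 MB, dim = 100, T = float (4 bytes) ->
// train_num = 6 * 1024 * 1024 / 4 / 100 = 15728 vectors. If the segments
// hold more rows than that, the data is randomly sampled below; otherwise
// all rows are used for training.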
bool random_sample = true;
// make train num equal to data num
if (train_num >= data_num) {
train_num = data_num;
random_sample =
false;  // all data is used for training, so no random sampling is needed
trained_segments_num = segment_ids.size();
}
if (train_num < num_clusters) {
LOG_WARN(msg_header_ +
"kmeans train num: {} less than num_clusters: {}, skip "
"clustering",
train_num,
num_clusters);
throw SegcoreError(ErrorCode::ClusterSkip,
"sample data num less than num clusters");
}
size_t train_size_final = train_num * dim * sizeof(T);
knowhere::TimeRecorder rc(msg_header_ + "kmeans clustering",
2 /* log level: info */);
// if data_num is larger than max_train_size, we need to sample so that the train data fits in memory
// otherwise just load all the data for kmeans training
LOG_INFO(msg_header_ + "pull and sample {}GB data out of {}GB data",
train_size_final / 1024.0 / 1024.0 / 1024.0,
data_size / 1024.0 / 1024.0 / 1024.0);
auto buf = std::make_unique<uint8_t[]>(train_size_final);
SampleTrainData<T>(segment_ids,
insert_files,
num_rows,
train_size_final,
dim,
random_sample,
buf.get());
rc.RecordSection("sample done");
auto dataset = GenDataset(train_num, dim, buf.release());
dataset->SetIsOwner(true);
LOG_INFO(msg_header_ + "train data num: {}, dim: {}, num_clusters: {}",
train_num,
dim,
num_clusters);
knowhere::Json train_conf;
train_conf[NUM_CLUSTERS] = num_clusters;
// inside knowhere, we will record each kmeans iteration duration
// return id mapping
auto res = cluster_node.Train(*dataset, train_conf);
if (!res.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
fmt::format("failed to kmeans train: {}: {}",
KnowhereStatusString(res.error()),
res.what()));
}
res.value()->SetIsOwner(true);
rc.RecordSection("clustering train done");
dataset.reset(); // release train data
auto centroid_id_mapping =
reinterpret_cast<const uint32_t*>(res.value()->GetTensor());
auto centroids_res = cluster_node.GetCentroids();
if (!centroids_res.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
fmt::format("failed to get centroids: {}: {}",
KnowhereStatusString(res.error()),
res.what()));
}
// centroids owned by cluster_node
centroids_res.value()->SetIsOwner(false);
auto centroids =
reinterpret_cast<const T*>(centroids_res.value()->GetTensor());
auto centroid_stats = CentroidsToPB<T>(centroids, num_clusters, dim);
auto id_mapping_stats = CentroidIdMappingToPB(centroid_id_mapping,
segment_ids,
trained_segments_num,
num_rows,
num_clusters);
// upload
StreamingAssignandUpload<T>(cluster_node,
config,
centroid_stats,
id_mapping_stats,
segment_ids,
insert_files,
num_rows,
dim,
trained_segments_num,
num_clusters);
rc.RecordSection("clustering result upload done");
rc.ElapseFromBegin("clustering done");
}
template void
KmeansClustering::StreamingAssignandUpload<float>(
knowhere::Cluster<knowhere::ClusterNode>& cluster_node,
const milvus::proto::clustering::AnalyzeInfo& config,
const milvus::proto::clustering::ClusteringCentroidsStats& centroid_stats,
const std::vector<
milvus::proto::clustering::ClusteringCentroidIdMappingStats>&
id_mapping_stats,
const std::vector<int64_t>& segment_ids,
const std::map<int64_t, std::vector<std::string>>& insert_files,
const std::map<int64_t, int64_t>& num_rows,
const int64_t dim,
const int64_t trained_segments_num,
const int64_t num_clusters);
template void
KmeansClustering::FetchDataFiles<float>(uint8_t* buf,
const int64_t expected_train_size,
const int64_t expected_remote_file_size,
const std::vector<std::string>& files,
const int64_t dim,
int64_t& offset);
template void
KmeansClustering::SampleTrainData<float>(
const std::vector<int64_t>& segment_ids,
const std::map<int64_t, std::vector<std::string>>& segment_file_paths,
const std::map<int64_t, int64_t>& segment_num_rows,
const int64_t expected_train_size,
const int64_t dim,
const bool random_sample,
uint8_t* buf);
template void
KmeansClustering::Run<float>(
const milvus::proto::clustering::AnalyzeInfo& config);
template milvus::proto::clustering::ClusteringCentroidsStats
KmeansClustering::CentroidsToPB<float>(const float* centroids,
const int64_t num_clusters,
const int64_t dim);
template bool
KmeansClustering::IsDataSkew<float>(
const milvus::proto::clustering::AnalyzeInfo& config,
const int64_t dim,
std::vector<int64_t>& num_in_each_centroid);
} // namespace milvus::clustering
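A minimal, self-contained sketch (editorial, not part of this commit) of the data-skew heuristic implemented by IsDataSkew above, applied to toy cluster sizes. It uses the same three rules: the smallest cluster against min_cluster_ratio * average, the largest cluster against max_cluster_ratio * average, and the largest cluster's byte size against max_cluster_size; the ratio and size values mirror the defaults added to milvus.yaml in this commit.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>
// Toy stand-in for KmeansClustering::IsDataSkew: returns true when the
// cluster-size distribution violates any of the three thresholds.
static bool
IsDataSkewToy(std::vector<int64_t> num_in_each_centroid,
              int64_t dim,
              double min_cluster_ratio,    // e.g. 0.01 (minClusterSizeRatio)
              double max_cluster_ratio,    // e.g. 10   (maxClusterSizeRatio)
              int64_t max_cluster_size) {  // e.g. 5 GiB (maxClusterSize)
    std::sort(num_in_each_centroid.begin(), num_in_each_centroid.end());
    double avg = std::accumulate(num_in_each_centroid.begin(),
                                 num_in_each_centroid.end(),
                                 int64_t(0)) /
                 double(num_in_each_centroid.size());
    if (num_in_each_centroid.front() <= min_cluster_ratio * avg) {
        return true;  // smallest cluster far below the average
    }
    if (num_in_each_centroid.back() >= max_cluster_ratio * avg) {
        return true;  // largest cluster far above the average
    }
    if (num_in_each_centroid.back() * dim * int64_t(sizeof(float)) >=
        max_cluster_size) {
        return true;  // largest cluster too large in bytes
    }
    return false;
}
int
main() {
    int64_t dim = 100;
    // Balanced: four clusters of 2500 rows each, average 2500 -> no rule fires.
    std::printf("balanced: %d\n",
                IsDataSkewToy({2500, 2500, 2500, 2500}, dim, 0.01, 10.0, 5LL << 30));
    // Skewed: {10, 1000, 1000, 7990}, average 2500; the smallest cluster (10)
    // is <= 0.01 * 2500 = 25, so skew is reported.
    std::printf("skewed:   %d\n",
                IsDataSkewToy({10, 1000, 1000, 7990}, dim, 0.01, 10.0, 5LL << 30));
    return 0;
}

When the real check fires, StreamingAssignandUpload removes the files it has already uploaded and throws ErrorCode::ClusterSkip, so the analyze task is skipped rather than producing a lopsided clustering.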

View File

@ -0,0 +1,157 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <vector>
#include "storage/MemFileManagerImpl.h"
#include "storage/space.h"
#include "pb/clustering.pb.h"
#include "knowhere/cluster/cluster_factory.h"
namespace milvus::clustering {
// after the clustering result has been uploaded, return the result meta for use on the golang side
struct ClusteringResultMeta {
std::string centroid_path; // centroid result path
int64_t centroid_file_size; // centroid result size
std::unordered_map<std::string, int64_t>
id_mappings; // id mapping result path/size for each segment
};
class KmeansClustering {
public:
explicit KmeansClustering(
const storage::FileManagerContext& file_manager_context =
storage::FileManagerContext());
// each call performs a brand-new kmeans training
template <typename T>
void
Run(const milvus::proto::clustering::AnalyzeInfo& config);
// should never be called before run
ClusteringResultMeta
GetClusteringResultMeta() {
if (!is_runned_) {
throw SegcoreError(
ErrorCode::UnexpectedError,
"clustering result is not ready before kmeans run");
}
return cluster_result_;
}
// exposed publicly for unit tests
inline std::string
GetRemoteCentroidsObjectPrefix() const {
auto index_meta_ = file_manager_->GetIndexMeta();
auto field_meta_ = file_manager_->GetFieldDataMeta();
return file_manager_->GetChunkManager()->GetRootPath() + "/" +
std::string(ANALYZE_ROOT_PATH) + "/" +
std::to_string(index_meta_.build_id) + "/" +
std::to_string(index_meta_.index_version) + "/" +
std::to_string(field_meta_.collection_id) + "/" +
std::to_string(field_meta_.partition_id) + "/" +
std::to_string(field_meta_.field_id);
}
inline std::string
GetRemoteCentroidIdMappingObjectPrefix(int64_t segment_id) const {
auto index_meta_ = file_manager_->GetIndexMeta();
auto field_meta_ = file_manager_->GetFieldDataMeta();
return file_manager_->GetChunkManager()->GetRootPath() + "/" +
std::string(ANALYZE_ROOT_PATH) + "/" +
std::to_string(index_meta_.build_id) + "/" +
std::to_string(index_meta_.index_version) + "/" +
std::to_string(field_meta_.collection_id) + "/" +
std::to_string(field_meta_.partition_id) + "/" +
std::to_string(field_meta_.field_id) + "/" +
std::to_string(segment_id);
}
~KmeansClustering() = default;
private:
template <typename T>
void
StreamingAssignandUpload(
knowhere::Cluster<knowhere::ClusterNode>& cluster_node,
const milvus::proto::clustering::AnalyzeInfo& config,
const milvus::proto::clustering::ClusteringCentroidsStats&
centroid_stats,
const std::vector<
milvus::proto::clustering::ClusteringCentroidIdMappingStats>&
id_mapping_stats,
const std::vector<int64_t>& segment_ids,
const std::map<int64_t, std::vector<std::string>>& insert_files,
const std::map<int64_t, int64_t>& num_rows,
const int64_t dim,
const int64_t trained_segments_num,
const int64_t num_clusters);
template <typename T>
void
FetchDataFiles(uint8_t* buf,
const int64_t expected_train_size,
const int64_t expected_remote_file_size,
const std::vector<std::string>& files,
const int64_t dim,
int64_t& offset);
// given all possible segments, sample data to buffer
template <typename T>
void
SampleTrainData(
const std::vector<int64_t>& segment_ids,
const std::map<int64_t, std::vector<std::string>>& segment_file_paths,
const std::map<int64_t, int64_t>& segment_num_rows,
const int64_t expected_train_size,
const int64_t dim,
const bool random_sample,
uint8_t* buf);
// transform the centroid results into PB format for later use on the golang side
template <typename T>
milvus::proto::clustering::ClusteringCentroidsStats
CentroidsToPB(const T* centroids,
const int64_t num_clusters,
const int64_t dim);
// split the flattened id-mapping result into per-segment PB files for later use on the golang side
std::vector<milvus::proto::clustering::ClusteringCentroidIdMappingStats>
CentroidIdMappingToPB(const uint32_t* centroid_id_mapping,
const std::vector<int64_t>& segment_ids,
const int64_t trained_segments_num,
const std::map<int64_t, int64_t>& num_row_map,
const int64_t num_clusters);
template <typename T>
bool
IsDataSkew(const milvus::proto::clustering::AnalyzeInfo& config,
const int64_t dim,
std::vector<int64_t>& num_in_each_centroid);
std::unique_ptr<storage::MemFileManagerImpl> file_manager_;
ClusteringResultMeta cluster_result_;
bool is_runned_ = false;
std::string msg_header_;
};
using KmeansClusteringPtr = std::unique_ptr<KmeansClustering>;
} // namespace milvus::clustering

View File

@ -0,0 +1,157 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <memory>
#ifdef __linux__
#include <malloc.h>
#endif
#include "analyze_c.h"
#include "common/type_c.h"
#include "type_c.h"
#include "types.h"
#include "index/Utils.h"
#include "index/Meta.h"
#include "storage/Util.h"
#include "pb/clustering.pb.h"
#include "clustering/KmeansClustering.h"
using namespace milvus;
milvus::storage::StorageConfig
get_storage_config(const milvus::proto::clustering::StorageConfig& config) {
auto storage_config = milvus::storage::StorageConfig();
storage_config.address = std::string(config.address());
storage_config.bucket_name = std::string(config.bucket_name());
storage_config.access_key_id = std::string(config.access_keyid());
storage_config.access_key_value = std::string(config.secret_access_key());
storage_config.root_path = std::string(config.root_path());
storage_config.storage_type = std::string(config.storage_type());
storage_config.cloud_provider = std::string(config.cloud_provider());
storage_config.iam_endpoint = std::string(config.iamendpoint());
storage_config.cloud_provider = std::string(config.cloud_provider());
storage_config.useSSL = config.usessl();
storage_config.sslCACert = config.sslcacert();
storage_config.useIAM = config.useiam();
storage_config.region = config.region();
storage_config.useVirtualHost = config.use_virtual_host();
storage_config.requestTimeoutMs = config.request_timeout_ms();
return storage_config;
}
CStatus
Analyze(CAnalyze* res_analyze,
const uint8_t* serialized_analyze_info,
const uint64_t len) {
try {
auto analyze_info =
std::make_unique<milvus::proto::clustering::AnalyzeInfo>();
auto res = analyze_info->ParseFromArray(serialized_analyze_info, len);
AssertInfo(res, "Unmarshall analyze info failed");
auto field_type =
static_cast<DataType>(analyze_info->field_schema().data_type());
auto field_id = analyze_info->field_schema().fieldid();
// init file manager
milvus::storage::FieldDataMeta field_meta{analyze_info->collectionid(),
analyze_info->partitionid(),
0,
field_id};
milvus::storage::IndexMeta index_meta{
0, field_id, analyze_info->buildid(), analyze_info->version()};
auto storage_config =
get_storage_config(analyze_info->storage_config());
auto chunk_manager =
milvus::storage::CreateChunkManager(storage_config);
milvus::storage::FileManagerContext fileManagerContext(
field_meta, index_meta, chunk_manager);
if (field_type != DataType::VECTOR_FLOAT) {
throw SegcoreError(
DataTypeInvalid,
fmt::format("invalid data type for clustering is {}",
std::to_string(int(field_type))));
}
auto clusteringJob =
std::make_unique<milvus::clustering::KmeansClustering>(
fileManagerContext);
clusteringJob->Run<float>(*analyze_info);
*res_analyze = clusteringJob.release();
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
return status;
} catch (SegcoreError& e) {
auto status = CStatus();
status.error_code = e.get_error_code();
status.error_msg = strdup(e.what());
return status;
} catch (std::exception& e) {
auto status = CStatus();
status.error_code = UnexpectedError;
status.error_msg = strdup(e.what());
return status;
}
}
CStatus
DeleteAnalyze(CAnalyze analyze) {
auto status = CStatus();
try {
AssertInfo(analyze, "failed to delete analyze, passed index was null");
auto real_analyze =
reinterpret_cast<milvus::clustering::KmeansClustering*>(analyze);
delete real_analyze;
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
status.error_code = UnexpectedError;
status.error_msg = strdup(e.what());
}
return status;
}
CStatus
GetAnalyzeResultMeta(CAnalyze analyze,
char** centroid_path,
int64_t* centroid_file_size,
void* id_mapping_paths,
int64_t* id_mapping_sizes) {
auto status = CStatus();
try {
AssertInfo(analyze,
"failed to serialize analyze to binary set, passed index "
"was null");
auto real_analyze =
reinterpret_cast<milvus::clustering::KmeansClustering*>(analyze);
auto res = real_analyze->GetClusteringResultMeta();
*centroid_path = res.centroid_path.data();
*centroid_file_size = res.centroid_file_size;
auto& map_ = res.id_mappings;
const char** id_mapping_paths_ = (const char**)id_mapping_paths;
size_t i = 0;
for (auto it = map_.begin(); it != map_.end(); ++it, i++) {
id_mapping_paths_[i] = it->first.data();
id_mapping_sizes[i] = it->second;
}
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
status.error_code = UnexpectedError;
status.error_msg = strdup(e.what());
}
return status;
}
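For orientation, a hedged usage sketch (editorial, not part of this commit) of the C API above: build and serialize an AnalyzeInfo protobuf, call Analyze, read the result meta, then free the handle. The proto setters and enum values (set_collectionid, mutable_field_schema, FloatVector, ...) are inferred from the accessors used in Analyze() and may differ slightly from the generated code; the storage config and binlog path are placeholders, and the id-mapping arrays are sized to the number of segments in the request, which matches how GetAnalyzeResultMeta lays out its output. The Go caller is presumably expected to drive the same sequence through cgo.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
#include "clustering/analyze_c.h"
#include "pb/clustering.pb.h"
int
RunAnalyzeOnce() {
    milvus::proto::clustering::AnalyzeInfo info;
    info.set_collectionid(1);
    info.set_partitionid(2);
    info.set_buildid(1000);
    info.set_version(1);
    info.set_dim(100);
    info.set_num_clusters(8);
    info.set_train_size(int64_t(25) * 1024 * 1024 * 1024);
    info.set_min_cluster_ratio(0.01);
    info.set_max_cluster_ratio(10.0);
    info.set_max_cluster_size(int64_t(5) * 1024 * 1024 * 1024);
    info.mutable_field_schema()->set_fieldid(101);
    info.mutable_field_schema()->set_data_type(
        milvus::proto::schema::DataType::FloatVector);
    // one segment (id 3) with one binlog file; the path is a placeholder
    (*info.mutable_insert_files())[3].add_insert_files("1/2/3/101/0");
    (*info.mutable_num_rows())[3] = 10000;
    auto* sc = info.mutable_storage_config();
    sc->set_storage_type("local");
    sc->set_root_path("/tmp/analyze-sketch");
    std::string buf;
    info.SerializeToString(&buf);
    CAnalyze handle = nullptr;
    CStatus status = Analyze(&handle,
                             reinterpret_cast<const uint8_t*>(buf.data()),
                             buf.size());
    if (status.error_code != Success) {
        std::printf("analyze failed: %s\n", status.error_msg);
        return -1;
    }
    // one id-mapping path/size pair is produced per segment in the request
    char* centroid_path = nullptr;
    int64_t centroid_size = 0;
    std::vector<const char*> mapping_paths(1, nullptr);
    std::vector<int64_t> mapping_sizes(1, 0);
    GetAnalyzeResultMeta(handle,
                         &centroid_path,
                         &centroid_size,
                         mapping_paths.data(),
                         mapping_sizes.data());
    std::printf("centroids: %s (%lld bytes)\n",
                centroid_path,
                static_cast<long long>(centroid_size));
    DeleteAnalyze(handle);
    return 0;
}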

View File

@ -0,0 +1,40 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include "common/type_c.h"
#include "common/binary_set_c.h"
#include "clustering/type_c.h"
CStatus
Analyze(CAnalyze* res_analyze,
const uint8_t* serialized_analyze_info,
const uint64_t len);
CStatus
DeleteAnalyze(CAnalyze analyze);
CStatus
GetAnalyzeResultMeta(CAnalyze analyze,
char** centroid_path,
int64_t* centroid_file_size,
void* id_mapping_paths,
int64_t* id_mapping_sizes);
#ifdef __cplusplus
};
#endif

View File

@ -0,0 +1,69 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "common/type_c.h"
#include <future>
#include "storage/ThreadPools.h"
#include "common/FieldData.h"
#include "common/LoadInfo.h"
#include "knowhere/comp/index_param.h"
#include "parquet/schema.h"
#include "storage/PayloadStream.h"
#include "storage/FileManager.h"
#include "storage/BinlogReader.h"
#include "storage/ChunkManager.h"
#include "storage/DataCodec.h"
#include "storage/Types.h"
#include "storage/space.h"
namespace milvus::clustering {
void
AddClusteringResultFiles(milvus::storage::ChunkManager* remote_chunk_manager,
const uint8_t* data,
const int64_t data_size,
const std::string& remote_prefix,
std::unordered_map<std::string, int64_t>& map) {
remote_chunk_manager->Write(
remote_prefix, const_cast<uint8_t*>(data), data_size);
map[remote_prefix] = data_size;
}
void
RemoveClusteringResultFiles(
milvus::storage::ChunkManager* remote_chunk_manager,
const std::unordered_map<std::string, int64_t>& map) {
auto& pool = ThreadPools::GetThreadPool(milvus::ThreadPoolPriority::MIDDLE);
std::vector<std::future<void>> futures;
for (auto& [file_path, file_size] : map) {
futures.push_back(pool.Submit(
[&, path = file_path]() { remote_chunk_manager->Remove(path); }));
}
std::exception_ptr first_exception = nullptr;
for (auto& future : futures) {
try {
future.get();
} catch (...) {
if (!first_exception) {
first_exception = std::current_exception();
}
}
}
if (first_exception) {
std::rethrow_exception(first_exception);
}
}
} // namespace milvus::clustering

View File

@ -0,0 +1,9 @@
libdir=@CMAKE_INSTALL_FULL_LIBDIR@
includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
Name: Milvus Clustering
Description: Clustering modules for Milvus
Version: @MILVUS_VERSION@
Libs: -L${libdir} -lmilvus_clustering
Cflags: -I${includedir}

View File

@ -0,0 +1,17 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "common/type_c.h"
typedef void* CAnalyze;
typedef void* CAnalyzeInfo;

View File

@ -0,0 +1,41 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <stdint.h>
#include <string>
#include <vector>
#include "common/Types.h"
#include "index/Index.h"
#include "storage/Types.h"
struct AnalyzeInfo {
int64_t collection_id;
int64_t partition_id;
int64_t field_id;
int64_t task_id;
int64_t version;
std::string field_name;
milvus::DataType field_type;
int64_t dim;
int64_t num_clusters;
int64_t train_size;
std::map<int64_t, std::vector<std::string>>
insert_files; // segment_id->files
std::map<int64_t, int64_t> num_rows;
milvus::storage::StorageConfig storage_config;
milvus::Config config;
};

View File

@ -38,6 +38,11 @@ const char INDEX_BUILD_ID_KEY[] = "indexBuildID";
const char INDEX_ROOT_PATH[] = "index_files";
const char RAWDATA_ROOT_PATH[] = "raw_datas";
const char ANALYZE_ROOT_PATH[] = "analyze_stats";
const char CENTROIDS_NAME[] = "centroids";
const char OFFSET_MAPPING_NAME[] = "offset_mapping";
const char NUM_CLUSTERS[] = "num_clusters";
const char KMEANS_CLUSTER[] = "KMEANS";
const char VEC_OPT_FIELDS[] = "opt_fields";
const char DEFAULT_PLANNODE_ID[] = "0";

View File

@ -60,6 +60,7 @@ enum ErrorCode {
UnistdError = 2030,
MetricTypeNotMatch = 2031,
DimNotMatch = 2032,
ClusterSkip = 2033,
KnowhereError = 2100,
};

View File

@ -14,6 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <stdint.h>
#include <string>
#include <vector>

View File

@ -15,7 +15,7 @@ file(GLOB_RECURSE milvus_proto_srcs
"${CMAKE_CURRENT_SOURCE_DIR}/*.cc")
add_library(milvus_proto STATIC
${milvus_proto_srcs}
)
)
message(STATUS "milvus proto sources: " ${milvus_proto_srcs})
target_link_libraries( milvus_proto PUBLIC ${CONAN_LIBS} )

View File

@ -442,6 +442,7 @@ SortByPath(std::vector<std::string>& paths) {
std::stol(b.substr(b.find_last_of("/") + 1));
});
}
template <typename DataType>
std::string
DiskFileManagerImpl::CacheRawDataToDisk(std::vector<std::string> remote_files) {

View File

@ -130,6 +130,11 @@ class FileManagerImpl : public knowhere::FileManager {
return index_meta_;
}
virtual ChunkManagerPtr
GetChunkManager() const {
return rcm_;
}
virtual std::string
GetRemoteIndexObjectPrefix() const {
return rcm_->GetRootPath() + "/" + std::string(INDEX_ROOT_PATH) + "/" +

View File

@ -69,6 +69,12 @@ set(MILVUS_TEST_FILES
test_regex_query.cpp
)
if ( INDEX_ENGINE STREQUAL "cardinal" )
set(MILVUS_TEST_FILES
${MILVUS_TEST_FILES}
test_kmeans_clustering.cpp)
endif()
if ( BUILD_DISK_ANN STREQUAL "ON" )
set(MILVUS_TEST_FILES
${MILVUS_TEST_FILES}
@ -121,6 +127,7 @@ if (LINUX)
milvus_segcore
milvus_storage
milvus_indexbuilder
milvus_clustering
milvus_common
)
install(TARGETS index_builder_test DESTINATION unittest)
@ -135,6 +142,7 @@ target_link_libraries(all_tests
milvus_segcore
milvus_storage
milvus_indexbuilder
milvus_clustering
pthread
milvus_common
milvus_exec

View File

@ -0,0 +1,321 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <functional>
#include <fstream>
#include <boost/filesystem.hpp>
#include <numeric>
#include <unordered_set>
#include "common/Tracer.h"
#include "common/EasyAssert.h"
#include "index/InvertedIndexTantivy.h"
#include "storage/Util.h"
#include "storage/InsertData.h"
#include "clustering/KmeansClustering.h"
#include "storage/LocalChunkManagerSingleton.h"
#include "test_utils/indexbuilder_test_utils.h"
#include "test_utils/storage_test_utils.h"
#include "index/Meta.h"
using namespace milvus;
void
ReadPBFile(std::string& file_path, google::protobuf::Message& message) {
std::ifstream infile;
infile.open(file_path.data(), std::ios_base::binary);
if (infile.fail()) {
std::stringstream err_msg;
err_msg << "Error: open local file '" << file_path << " failed, "
<< strerror(errno);
throw SegcoreError(FileOpenFailed, err_msg.str());
}
infile.seekg(0, std::ios::beg);
if (!message.ParseFromIstream(&infile)) {
std::stringstream err_msg;
err_msg << "Error: parse pb file '" << file_path << " failed, "
<< strerror(errno);
throw SegcoreError(FileReadFailed, err_msg.str());
}
infile.close();
}
milvus::proto::clustering::AnalyzeInfo
transforConfigToPB(const Config& config) {
milvus::proto::clustering::AnalyzeInfo analyze_info;
analyze_info.set_num_clusters(config["num_clusters"]);
analyze_info.set_max_cluster_ratio(config["max_cluster_ratio"]);
analyze_info.set_min_cluster_ratio(config["min_cluster_ratio"]);
analyze_info.set_max_cluster_size(config["max_cluster_size"]);
auto& num_rows = *analyze_info.mutable_num_rows();
for (const auto& [k, v] :
milvus::index::GetValueFromConfig<std::map<int64_t, int64_t>>(
config, "num_rows")
.value()) {
num_rows[k] = v;
}
auto& insert_files = *analyze_info.mutable_insert_files();
auto insert_files_map =
milvus::index::GetValueFromConfig<
std::map<int64_t, std::vector<std::string>>>(config, "insert_files")
.value();
for (const auto& [k, v] : insert_files_map) {
for (auto i = 0; i < v.size(); i++)
insert_files[k].add_insert_files(v[i]);
}
analyze_info.set_dim(config["dim"]);
analyze_info.set_train_size(config["train_size"]);
return analyze_info;
}
template <typename T>
void
CheckResultCorrectness(
const milvus::clustering::KmeansClusteringPtr& clusteringJob,
int64_t segment_id,
int64_t segment_id2,
int64_t dim,
int64_t nb,
int expected_num_clusters,
bool check_centroids) {
std::string centroids_path_prefix =
clusteringJob->GetRemoteCentroidsObjectPrefix();
std::string centroids_name = std::string(CENTROIDS_NAME);
std::string centroid_path = centroids_path_prefix + "/" + centroids_name;
milvus::proto::clustering::ClusteringCentroidsStats stats;
ReadPBFile(centroid_path, stats);
std::vector<T> centroids;
for (const auto& centroid : stats.centroids()) {
const auto& float_vector = centroid.float_vector();
for (float value : float_vector.data()) {
centroids.emplace_back(T(value));
}
}
ASSERT_EQ(centroids.size(), expected_num_clusters * dim);
std::string offset_mapping_name = std::string(OFFSET_MAPPING_NAME);
std::string centroid_id_mapping_path =
clusteringJob->GetRemoteCentroidIdMappingObjectPrefix(segment_id) +
"/" + offset_mapping_name;
milvus::proto::clustering::ClusteringCentroidIdMappingStats mapping_stats;
std::string centroid_id_mapping_path2 =
clusteringJob->GetRemoteCentroidIdMappingObjectPrefix(segment_id2) +
"/" + offset_mapping_name;
milvus::proto::clustering::ClusteringCentroidIdMappingStats mapping_stats2;
ReadPBFile(centroid_id_mapping_path, mapping_stats);
ReadPBFile(centroid_id_mapping_path2, mapping_stats2);
std::vector<uint32_t> centroid_id_mapping;
std::vector<int64_t> num_in_centroid;
for (const auto id : mapping_stats.centroid_id_mapping()) {
centroid_id_mapping.emplace_back(id);
ASSERT_TRUE(id < expected_num_clusters);
}
ASSERT_EQ(centroid_id_mapping.size(), nb);
for (const auto num : mapping_stats.num_in_centroid()) {
num_in_centroid.emplace_back(num);
}
ASSERT_EQ(
std::accumulate(num_in_centroid.begin(), num_in_centroid.end(), 0), nb);
// the second id mapping should be the same as the first, since the segment data is identical
if (check_centroids) {
for (int64_t i = 0; i < mapping_stats2.centroid_id_mapping_size();
i++) {
ASSERT_EQ(mapping_stats2.centroid_id_mapping(i),
centroid_id_mapping[i]);
}
for (int64_t i = 0; i < mapping_stats2.num_in_centroid_size(); i++) {
ASSERT_EQ(mapping_stats2.num_in_centroid(i), num_in_centroid[i]);
}
}
}
template <typename T, DataType dtype>
void
test_run() {
int64_t collection_id = 1;
int64_t partition_id = 2;
int64_t segment_id = 3;
int64_t segment_id2 = 4;
int64_t field_id = 101;
int64_t index_build_id = 1000;
int64_t index_version = 10000;
int64_t dim = 100;
int64_t nb = 10000;
auto field_meta =
gen_field_meta(collection_id, partition_id, segment_id, field_id);
auto index_meta =
gen_index_meta(segment_id, field_id, index_build_id, index_version);
std::string root_path = "/tmp/test-kmeans-clustering/";
auto storage_config = gen_local_storage_config(root_path);
auto cm = storage::CreateChunkManager(storage_config);
std::vector<T> data_gen(nb * dim);
for (int64_t i = 0; i < nb * dim; ++i) {
data_gen[i] = rand();
}
auto field_data = storage::CreateFieldData(dtype, dim);
field_data->FillFieldData(data_gen.data(), data_gen.size() / dim);
storage::InsertData insert_data(field_data);
insert_data.SetFieldDataMeta(field_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::Remote);
auto get_binlog_path = [=](int64_t log_id) {
return fmt::format("{}/{}/{}/{}/{}",
collection_id,
partition_id,
segment_id,
field_id,
log_id);
};
auto log_path = get_binlog_path(0);
auto cm_w = ChunkManagerWrapper(cm);
cm_w.Write(log_path, serialized_bytes.data(), serialized_bytes.size());
storage::FileManagerContext ctx(field_meta, index_meta, cm);
std::map<int64_t, std::vector<std::string>> remote_files;
std::map<int64_t, int64_t> num_rows;
// two segments
remote_files[segment_id] = {log_path};
remote_files[segment_id2] = {log_path};
num_rows[segment_id] = nb;
num_rows[segment_id2] = nb;
Config config;
config["max_cluster_ratio"] = 10.0;
config["max_cluster_size"] = 5L * 1024 * 1024 * 1024;
// no need to sample train data
{
config["min_cluster_ratio"] = 0.01;
config["insert_files"] = remote_files;
config["num_clusters"] = 8;
config["train_size"] = 25L * 1024 * 1024 * 1024; // 25GB
config["dim"] = dim;
config["num_rows"] = num_rows;
auto clusteringJob =
std::make_unique<clustering::KmeansClustering>(ctx);
clusteringJob->Run<T>(transforConfigToPB(config));
CheckResultCorrectness<T>(clusteringJob,
segment_id,
segment_id2,
dim,
nb,
config["num_clusters"],
true);
}
{
config["min_cluster_ratio"] = 0.01;
config["insert_files"] = remote_files;
config["num_clusters"] = 200;
config["train_size"] = 25L * 1024 * 1024 * 1024; // 25GB
config["dim"] = dim;
config["num_rows"] = num_rows;
auto clusteringJob =
std::make_unique<clustering::KmeansClustering>(ctx);
clusteringJob->Run<T>(transforConfigToPB(config));
CheckResultCorrectness<T>(clusteringJob,
segment_id,
segment_id2,
dim,
nb,
config["num_clusters"],
true);
}
// num clusters larger than train num
{
EXPECT_THROW(
try {
config["min_cluster_ratio"] = 0.01;
config["insert_files"] = remote_files;
config["num_clusters"] = 100000;
config["train_size"] = 25L * 1024 * 1024 * 1024; // 25GB
config["dim"] = dim;
config["num_rows"] = num_rows;
auto clusteringJob =
std::make_unique<clustering::KmeansClustering>(ctx);
clusteringJob->Run<T>(transforConfigToPB(config));
} catch (SegcoreError& e) {
ASSERT_EQ(e.get_error_code(), ErrorCode::ClusterSkip);
throw e;
},
SegcoreError);
}
// data skew
{
EXPECT_THROW(
try {
config["min_cluster_ratio"] = 0.98;
config["insert_files"] = remote_files;
config["num_clusters"] = 8;
config["train_size"] = 25L * 1024 * 1024 * 1024; // 25GB
config["dim"] = dim;
config["num_rows"] = num_rows;
auto clusteringJob =
std::make_unique<clustering::KmeansClustering>(ctx);
clusteringJob->Run<T>(transforConfigToPB(config));
} catch (SegcoreError& e) {
ASSERT_EQ(e.get_error_code(), ErrorCode::ClusterSkip);
throw e;
},
SegcoreError);
}
// need to sample train data case1
{
config["min_cluster_ratio"] = 0.01;
config["insert_files"] = remote_files;
config["num_clusters"] = 8;
config["train_size"] = 1536L * 1024; // 1.5MB
config["dim"] = dim;
config["num_rows"] = num_rows;
auto clusteringJob =
std::make_unique<clustering::KmeansClustering>(ctx);
clusteringJob->Run<T>(transforConfigToPB(config));
CheckResultCorrectness<T>(clusteringJob,
segment_id,
segment_id2,
dim,
nb,
config["num_clusters"],
true);
}
// need to sample train data case2
{
config["min_cluster_ratio"] = 0.01;
config["insert_files"] = remote_files;
config["num_clusters"] = 8;
config["train_size"] = 6L * 1024 * 1024; // 6MB
config["dim"] = dim;
config["num_rows"] = num_rows;
auto clusteringJob =
std::make_unique<clustering::KmeansClustering>(ctx);
clusteringJob->Run<T>(transforConfigToPB(config));
CheckResultCorrectness<T>(clusteringJob,
segment_id,
segment_id2,
dim,
nb,
config["num_clusters"],
true);
}
}
TEST(MajorCompaction, Naive) {
test_run<float, DataType::VECTOR_FLOAT>();
}

View File

@ -23,6 +23,7 @@
#include "storage/Types.h"
#include "storage/InsertData.h"
#include "storage/ThreadPools.h"
#include <boost/filesystem.hpp>
using milvus::DataType;
using milvus::FieldDataPtr;
@ -137,4 +138,61 @@ PutFieldData(milvus::storage::ChunkManager* remote_chunk_manager,
return remote_paths_to_size;
}
auto
gen_field_meta(int64_t collection_id = 1,
int64_t partition_id = 2,
int64_t segment_id = 3,
int64_t field_id = 101) -> milvus::storage::FieldDataMeta {
return milvus::storage::FieldDataMeta{
.collection_id = collection_id,
.partition_id = partition_id,
.segment_id = segment_id,
.field_id = field_id,
};
}
auto
gen_index_meta(int64_t segment_id = 3,
int64_t field_id = 101,
int64_t index_build_id = 1000,
int64_t index_version = 10000) -> milvus::storage::IndexMeta {
return milvus::storage::IndexMeta{
.segment_id = segment_id,
.field_id = field_id,
.build_id = index_build_id,
.index_version = index_version,
};
}
auto
gen_local_storage_config(const std::string& root_path)
-> milvus::storage::StorageConfig {
auto ret = milvus::storage::StorageConfig{};
ret.storage_type = "local";
ret.root_path = root_path;
return ret;
}
struct ChunkManagerWrapper {
ChunkManagerWrapper(milvus::storage::ChunkManagerPtr cm) : cm_(cm) {
}
~ChunkManagerWrapper() {
for (const auto& file : written_) {
cm_->Remove(file);
}
boost::filesystem::remove_all(cm_->GetRootPath());
}
void
Write(const std::string& filepath, void* buf, uint64_t len) {
written_.insert(filepath);
cm_->Write(filepath, buf, len);
}
const milvus::storage::ChunkManagerPtr cm_;
std::unordered_set<std::string> written_;
};
} // namespace

View File

@ -0,0 +1,182 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"fmt"
"sync"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/metastore"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
type analyzeMeta struct {
sync.RWMutex
ctx context.Context
catalog metastore.DataCoordCatalog
// taskID -> analyze task
// TODO: when to mark as dropped?
tasks map[int64]*indexpb.AnalyzeTask
}
func newAnalyzeMeta(ctx context.Context, catalog metastore.DataCoordCatalog) (*analyzeMeta, error) {
mt := &analyzeMeta{
ctx: ctx,
catalog: catalog,
tasks: make(map[int64]*indexpb.AnalyzeTask),
}
if err := mt.reloadFromKV(); err != nil {
return nil, err
}
return mt, nil
}
func (m *analyzeMeta) reloadFromKV() error {
record := timerecord.NewTimeRecorder("analyzeMeta-reloadFromKV")
// load analyze tasks
analyzeTasks, err := m.catalog.ListAnalyzeTasks(m.ctx)
if err != nil {
log.Warn("analyzeMeta reloadFromKV load analyze tasks failed", zap.Error(err))
return err
}
for _, analyzeTask := range analyzeTasks {
m.tasks[analyzeTask.TaskID] = analyzeTask
}
log.Info("analyzeMeta reloadFromKV done", zap.Duration("duration", record.ElapseSpan()))
return nil
}
func (m *analyzeMeta) saveTask(newTask *indexpb.AnalyzeTask) error {
if err := m.catalog.SaveAnalyzeTask(m.ctx, newTask); err != nil {
return err
}
m.tasks[newTask.TaskID] = newTask
return nil
}
func (m *analyzeMeta) GetTask(taskID int64) *indexpb.AnalyzeTask {
m.RLock()
defer m.RUnlock()
return m.tasks[taskID]
}
func (m *analyzeMeta) AddAnalyzeTask(task *indexpb.AnalyzeTask) error {
m.Lock()
defer m.Unlock()
log.Info("add analyze task", zap.Int64("taskID", task.TaskID),
zap.Int64("collectionID", task.CollectionID), zap.Int64("partitionID", task.PartitionID))
return m.saveTask(task)
}
func (m *analyzeMeta) DropAnalyzeTask(taskID int64) error {
m.Lock()
defer m.Unlock()
log.Info("drop analyze task", zap.Int64("taskID", taskID))
if err := m.catalog.DropAnalyzeTask(m.ctx, taskID); err != nil {
log.Warn("drop analyze task by catalog failed", zap.Int64("taskID", taskID),
zap.Error(err))
return err
}
delete(m.tasks, taskID)
return nil
}
func (m *analyzeMeta) UpdateVersion(taskID int64) error {
m.Lock()
defer m.Unlock()
t, ok := m.tasks[taskID]
if !ok {
return fmt.Errorf("there is no task with taskID: %d", taskID)
}
cloneT := proto.Clone(t).(*indexpb.AnalyzeTask)
cloneT.Version++
log.Info("update task version", zap.Int64("taskID", taskID), zap.Int64("newVersion", cloneT.Version))
return m.saveTask(cloneT)
}
func (m *analyzeMeta) BuildingTask(taskID, nodeID int64) error {
m.Lock()
defer m.Unlock()
t, ok := m.tasks[taskID]
if !ok {
return fmt.Errorf("there is no task with taskID: %d", taskID)
}
cloneT := proto.Clone(t).(*indexpb.AnalyzeTask)
cloneT.NodeID = nodeID
cloneT.State = indexpb.JobState_JobStateInProgress
log.Info("task will be building", zap.Int64("taskID", taskID), zap.Int64("nodeID", nodeID))
return m.saveTask(cloneT)
}
func (m *analyzeMeta) FinishTask(taskID int64, result *indexpb.AnalyzeResult) error {
m.Lock()
defer m.Unlock()
t, ok := m.tasks[taskID]
if !ok {
return fmt.Errorf("there is no task with taskID: %d", taskID)
}
log.Info("finish task meta...", zap.Int64("taskID", taskID), zap.String("state", result.GetState().String()),
zap.String("failReason", result.GetFailReason()))
cloneT := proto.Clone(t).(*indexpb.AnalyzeTask)
cloneT.State = result.GetState()
cloneT.FailReason = result.GetFailReason()
cloneT.CentroidsFile = result.GetCentroidsFile()
return m.saveTask(cloneT)
}
func (m *analyzeMeta) GetAllTasks() map[int64]*indexpb.AnalyzeTask {
m.RLock()
defer m.RUnlock()
return m.tasks
}
func (m *analyzeMeta) CheckCleanAnalyzeTask(taskID UniqueID) (bool, *indexpb.AnalyzeTask) {
m.RLock()
defer m.RUnlock()
if t, ok := m.tasks[taskID]; ok {
if t.State == indexpb.JobState_JobStateFinished {
return true, t
}
return false, t
}
return true, nil
}
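
The methods above form a small state machine for analyze tasks (add, bump version, mark building, record result). The sketch below shows one plausible ordering of those calls from DataCoord's point of view; it assumes the datacoord package context and is illustrative only, since the real call sites live in the task scheduler introduced elsewhere in this change.

// runAnalyzeTaskSketch is a hypothetical helper, not part of this change.
func runAnalyzeTaskSketch(am *analyzeMeta, taskID, nodeID int64) error {
	// 1. register the task; saveTask persists it through the catalog
	if err := am.AddAnalyzeTask(&indexpb.AnalyzeTask{TaskID: taskID}); err != nil {
		return err
	}
	// 2. bump the version before each (re)assignment attempt
	if err := am.UpdateVersion(taskID); err != nil {
		return err
	}
	// 3. mark the task as building on the chosen worker node
	if err := am.BuildingTask(taskID, nodeID); err != nil {
		return err
	}
	// 4. record the worker's result (the centroids file is kept on success)
	return am.FinishTask(taskID, &indexpb.AnalyzeResult{
		TaskID: taskID,
		State:  indexpb.JobState_JobStateFinished,
	})
}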

View File

@ -0,0 +1,267 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/metastore/mocks"
"github.com/milvus-io/milvus/internal/proto/indexpb"
)
type AnalyzeMetaSuite struct {
suite.Suite
collectionID int64
partitionID int64
fieldID int64
segmentIDs []int64
}
func (s *AnalyzeMetaSuite) initParams() {
s.collectionID = 100
s.partitionID = 101
s.fieldID = 102
s.segmentIDs = []int64{1000, 1001, 1002, 1003}
}
func (s *AnalyzeMetaSuite) Test_AnalyzeMeta() {
s.initParams()
catalog := mocks.NewDataCoordCatalog(s.T())
catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return([]*indexpb.AnalyzeTask{
{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 1,
State: indexpb.JobState_JobStateNone,
},
{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 2,
State: indexpb.JobState_JobStateInit,
},
{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 3,
State: indexpb.JobState_JobStateInProgress,
},
{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 4,
State: indexpb.JobState_JobStateRetry,
},
{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 5,
State: indexpb.JobState_JobStateFinished,
},
{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 6,
State: indexpb.JobState_JobStateFailed,
},
}, nil)
catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(nil)
catalog.EXPECT().DropAnalyzeTask(mock.Anything, mock.Anything).Return(nil)
ctx := context.Background()
am, err := newAnalyzeMeta(ctx, catalog)
s.NoError(err)
s.Equal(6, len(am.GetAllTasks()))
s.Run("GetTask", func() {
t := am.GetTask(1)
s.NotNil(t)
t = am.GetTask(100)
s.Nil(t)
})
s.Run("AddAnalyzeTask", func() {
t := &indexpb.AnalyzeTask{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 7,
}
err := am.AddAnalyzeTask(t)
s.NoError(err)
s.Equal(7, len(am.GetAllTasks()))
err = am.AddAnalyzeTask(t)
s.NoError(err)
s.Equal(7, len(am.GetAllTasks()))
})
s.Run("DropAnalyzeTask", func() {
err := am.DropAnalyzeTask(7)
s.NoError(err)
s.Equal(6, len(am.GetAllTasks()))
})
s.Run("UpdateVersion", func() {
err := am.UpdateVersion(1)
s.NoError(err)
s.Equal(int64(1), am.GetTask(1).Version)
})
s.Run("BuildingTask", func() {
err := am.BuildingTask(1, 1)
s.NoError(err)
s.Equal(indexpb.JobState_JobStateInProgress, am.GetTask(1).State)
})
s.Run("FinishTask", func() {
err := am.FinishTask(1, &indexpb.AnalyzeResult{
TaskID: 1,
State: indexpb.JobState_JobStateFinished,
})
s.NoError(err)
s.Equal(indexpb.JobState_JobStateFinished, am.GetTask(1).State)
})
}
func (s *AnalyzeMetaSuite) Test_failCase() {
s.initParams()
catalog := mocks.NewDataCoordCatalog(s.T())
catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, errors.New("error")).Once()
ctx := context.Background()
am, err := newAnalyzeMeta(ctx, catalog)
s.Error(err)
s.Nil(am)
catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return([]*indexpb.AnalyzeTask{
{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 1,
State: indexpb.JobState_JobStateInit,
},
{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 2,
State: indexpb.JobState_JobStateFinished,
},
}, nil)
am, err = newAnalyzeMeta(ctx, catalog)
s.NoError(err)
s.NotNil(am)
s.Equal(2, len(am.GetAllTasks()))
catalog.EXPECT().SaveAnalyzeTask(mock.Anything, mock.Anything).Return(errors.New("error"))
catalog.EXPECT().DropAnalyzeTask(mock.Anything, mock.Anything).Return(errors.New("error"))
s.Run("AddAnalyzeTask", func() {
t := &indexpb.AnalyzeTask{
CollectionID: s.collectionID,
PartitionID: s.partitionID,
FieldID: s.fieldID,
SegmentIDs: s.segmentIDs,
TaskID: 1111,
}
err := am.AddAnalyzeTask(t)
s.Error(err)
s.Nil(am.GetTask(1111))
})
s.Run("DropAnalyzeTask", func() {
err := am.DropAnalyzeTask(1)
s.Error(err)
s.NotNil(am.GetTask(1))
})
s.Run("UpdateVersion", func() {
err := am.UpdateVersion(777)
s.Error(err)
err = am.UpdateVersion(1)
s.Error(err)
s.Equal(int64(0), am.GetTask(1).Version)
})
s.Run("BuildingTask", func() {
err := am.BuildingTask(777, 1)
s.Error(err)
err = am.BuildingTask(1, 1)
s.Error(err)
s.Equal(int64(0), am.GetTask(1).NodeID)
s.Equal(indexpb.JobState_JobStateInit, am.GetTask(1).State)
})
s.Run("FinishTask", func() {
err := am.FinishTask(777, nil)
s.Error(err)
err = am.FinishTask(1, &indexpb.AnalyzeResult{
TaskID: 1,
State: indexpb.JobState_JobStateFinished,
})
s.Error(err)
s.Equal(indexpb.JobState_JobStateInit, am.GetTask(1).State)
})
s.Run("CheckCleanAnalyzeTask", func() {
canRecycle, t := am.CheckCleanAnalyzeTask(1)
s.False(canRecycle)
s.Equal(indexpb.JobState_JobStateInit, t.GetState())
canRecycle, t = am.CheckCleanAnalyzeTask(777)
s.True(canRecycle)
s.Nil(t)
canRecycle, t = am.CheckCleanAnalyzeTask(2)
s.True(canRecycle)
s.Equal(indexpb.JobState_JobStateFinished, t.GetState())
})
}
func TestAnalyzeMeta(t *testing.T) {
suite.Run(t, new(AnalyzeMetaSuite))
}

View File

@ -620,7 +620,7 @@ func (gc *garbageCollector) recycleUnusedIndexFiles(ctx context.Context) {
}
logger = logger.With(zap.Int64("buildID", buildID))
logger.Info("garbageCollector will recycle index files")
canRecycle, segIdx := gc.meta.indexMeta.CleanSegmentIndex(buildID)
canRecycle, segIdx := gc.meta.indexMeta.CheckCleanSegmentIndex(buildID)
if !canRecycle {
// Even if the index is marked as deleted, the index file will not be recycled, wait for the next gc,
// and delete all index files about the buildID at one time.

View File

@ -53,6 +53,7 @@ func (s *ImportCheckerSuite) SetupTest() {
catalog.EXPECT().ListIndexes(mock.Anything).Return(nil, nil)
catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return(nil, nil)
catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil)
catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil)
cluster := NewMockCluster(s.T())
alloc := NewNMockAllocator(s.T())

View File

@ -57,6 +57,7 @@ func (s *ImportSchedulerSuite) SetupTest() {
s.catalog.EXPECT().ListIndexes(mock.Anything).Return(nil, nil)
s.catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return(nil, nil)
s.catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil)
s.catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil)
s.cluster = NewMockCluster(s.T())
s.alloc = NewNMockAllocator(s.T())

View File

@ -153,6 +153,7 @@ func TestImportUtil_AssembleRequest(t *testing.T) {
catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return(nil, nil)
catalog.EXPECT().AddSegment(mock.Anything, mock.Anything).Return(nil)
catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil)
catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil)
alloc := NewNMockAllocator(t)
alloc.EXPECT().allocN(mock.Anything).RunAndReturn(func(n int64) (int64, int64, error) {
@ -233,6 +234,7 @@ func TestImportUtil_CheckDiskQuota(t *testing.T) {
catalog.EXPECT().ListChannelCheckpoint(mock.Anything).Return(nil, nil)
catalog.EXPECT().AddSegment(mock.Anything, mock.Anything).Return(nil)
catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil)
catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil)
imeta, err := NewImportMeta(catalog)
assert.NoError(t, err)
@ -408,6 +410,7 @@ func TestImportUtil_GetImportProgress(t *testing.T) {
catalog.EXPECT().AddSegment(mock.Anything, mock.Anything).Return(nil)
catalog.EXPECT().AlterSegments(mock.Anything, mock.Anything).Return(nil)
catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil)
catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil)
imeta, err := NewImportMeta(catalog)
assert.NoError(t, err)

View File

@ -1,575 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"path"
"sync"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
itypeutil "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/lock"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type indexTaskState int32
const (
// when we receive an index task
indexTaskInit indexTaskState = iota
// we've sent the index task to the scheduler and are waiting for the index to be built.
indexTaskInProgress
// task is done, waiting to be cleaned up
indexTaskDone
// the index task needs to be retried.
indexTaskRetry
reqTimeoutInterval = time.Second * 10
)
var TaskStateNames = map[indexTaskState]string{
0: "Init",
1: "InProgress",
2: "Done",
3: "Retry",
}
func (x indexTaskState) String() string {
ret, ok := TaskStateNames[x]
if !ok {
return "None"
}
return ret
}
type indexBuilder struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
taskMutex lock.RWMutex
scheduleDuration time.Duration
// TODO @xiaocai2333: use priority queue
tasks map[int64]indexTaskState
notifyChan chan struct{}
meta *meta
policy buildIndexPolicy
nodeManager *IndexNodeManager
chunkManager storage.ChunkManager
indexEngineVersionManager IndexEngineVersionManager
handler Handler
}
func newIndexBuilder(
ctx context.Context,
metaTable *meta, nodeManager *IndexNodeManager,
chunkManager storage.ChunkManager,
indexEngineVersionManager IndexEngineVersionManager,
handler Handler,
) *indexBuilder {
ctx, cancel := context.WithCancel(ctx)
ib := &indexBuilder{
ctx: ctx,
cancel: cancel,
meta: metaTable,
tasks: make(map[int64]indexTaskState),
notifyChan: make(chan struct{}, 1),
scheduleDuration: Params.DataCoordCfg.IndexTaskSchedulerInterval.GetAsDuration(time.Millisecond),
policy: defaultBuildIndexPolicy,
nodeManager: nodeManager,
chunkManager: chunkManager,
handler: handler,
indexEngineVersionManager: indexEngineVersionManager,
}
ib.reloadFromKV()
return ib
}
func (ib *indexBuilder) Start() {
ib.wg.Add(1)
go ib.schedule()
}
func (ib *indexBuilder) Stop() {
ib.cancel()
ib.wg.Wait()
}
func (ib *indexBuilder) reloadFromKV() {
segments := ib.meta.GetAllSegmentsUnsafe()
for _, segment := range segments {
for _, segIndex := range ib.meta.indexMeta.getSegmentIndexes(segment.ID) {
if segIndex.IsDeleted {
continue
}
if segIndex.IndexState == commonpb.IndexState_Unissued {
ib.tasks[segIndex.BuildID] = indexTaskInit
} else if segIndex.IndexState == commonpb.IndexState_InProgress {
ib.tasks[segIndex.BuildID] = indexTaskInProgress
}
}
}
}
// notify is a non-blocking notify function
func (ib *indexBuilder) notify() {
select {
case ib.notifyChan <- struct{}{}:
default:
}
}
func (ib *indexBuilder) enqueue(buildID UniqueID) {
defer ib.notify()
ib.taskMutex.Lock()
defer ib.taskMutex.Unlock()
if _, ok := ib.tasks[buildID]; !ok {
ib.tasks[buildID] = indexTaskInit
}
log.Info("indexBuilder enqueue task", zap.Int64("buildID", buildID))
}
func (ib *indexBuilder) schedule() {
// receive notifyChan
// time ticker
log.Ctx(ib.ctx).Info("index builder schedule loop start")
defer ib.wg.Done()
ticker := time.NewTicker(ib.scheduleDuration)
defer ticker.Stop()
for {
select {
case <-ib.ctx.Done():
log.Ctx(ib.ctx).Warn("index builder ctx done")
return
case _, ok := <-ib.notifyChan:
if ok {
ib.run()
}
// !ok means indexBuilder is closed.
case <-ticker.C:
ib.run()
}
}
}
func (ib *indexBuilder) run() {
ib.taskMutex.RLock()
buildIDs := make([]UniqueID, 0, len(ib.tasks))
for tID := range ib.tasks {
buildIDs = append(buildIDs, tID)
}
ib.taskMutex.RUnlock()
if len(buildIDs) > 0 {
log.Ctx(ib.ctx).Info("index builder task schedule", zap.Int("task num", len(buildIDs)))
}
ib.policy(buildIDs)
for _, buildID := range buildIDs {
ok := ib.process(buildID)
if !ok {
log.Ctx(ib.ctx).Info("there is no idle indexing node, wait a minute...")
break
}
}
}
func getBinLogIDs(segment *SegmentInfo, fieldID int64) []int64 {
binlogIDs := make([]int64, 0)
for _, fieldBinLog := range segment.GetBinlogs() {
if fieldBinLog.GetFieldID() == fieldID {
for _, binLog := range fieldBinLog.GetBinlogs() {
binlogIDs = append(binlogIDs, binLog.GetLogID())
}
break
}
}
return binlogIDs
}
func (ib *indexBuilder) process(buildID UniqueID) bool {
ib.taskMutex.RLock()
state := ib.tasks[buildID]
ib.taskMutex.RUnlock()
updateStateFunc := func(buildID UniqueID, state indexTaskState) {
ib.taskMutex.Lock()
defer ib.taskMutex.Unlock()
ib.tasks[buildID] = state
}
deleteFunc := func(buildID UniqueID) {
ib.taskMutex.Lock()
defer ib.taskMutex.Unlock()
delete(ib.tasks, buildID)
}
meta, exist := ib.meta.indexMeta.GetIndexJob(buildID)
if !exist {
log.Ctx(ib.ctx).Debug("index task has not exist in meta table, remove task", zap.Int64("buildID", buildID))
deleteFunc(buildID)
return true
}
switch state {
case indexTaskInit:
segment := ib.meta.GetSegment(meta.SegmentID)
if !isSegmentHealthy(segment) || !ib.meta.indexMeta.IsIndexExist(meta.CollectionID, meta.IndexID) {
log.Ctx(ib.ctx).Info("task is no need to build index, remove it", zap.Int64("buildID", buildID))
if err := ib.meta.indexMeta.DeleteTask(buildID); err != nil {
log.Ctx(ib.ctx).Warn("IndexCoord delete index failed", zap.Int64("buildID", buildID), zap.Error(err))
return false
}
deleteFunc(buildID)
return true
}
indexParams := ib.meta.indexMeta.GetIndexParams(meta.CollectionID, meta.IndexID)
indexType := GetIndexType(indexParams)
if isFlatIndex(indexType) || meta.NumRows < Params.DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64() {
log.Ctx(ib.ctx).Info("segment does not need index really", zap.Int64("buildID", buildID),
zap.Int64("segmentID", meta.SegmentID), zap.Int64("num rows", meta.NumRows))
if err := ib.meta.indexMeta.FinishTask(&indexpb.IndexTaskInfo{
BuildID: buildID,
State: commonpb.IndexState_Finished,
IndexFileKeys: nil,
SerializedSize: 0,
FailReason: "",
}); err != nil {
log.Ctx(ib.ctx).Warn("IndexCoord update index state fail", zap.Int64("buildID", buildID), zap.Error(err))
return false
}
updateStateFunc(buildID, indexTaskDone)
return true
}
// peek client
// if all IndexNodes are executing tasks, wait for one of them to finish.
nodeID, client := ib.nodeManager.PeekClient(meta)
if client == nil {
log.Ctx(ib.ctx).WithRateGroup("dc.indexBuilder", 1, 60).RatedInfo(5, "index builder peek client error, there is no available IndexNode")
return false
}
// update version and set nodeID
if err := ib.meta.indexMeta.UpdateVersion(buildID, nodeID); err != nil {
log.Ctx(ib.ctx).Warn("index builder update index version failed", zap.Int64("build", buildID), zap.Error(err))
return false
}
// vector index build needs information of optional scalar fields data
optionalFields := make([]*indexpb.OptionalFieldInfo, 0)
if Params.CommonCfg.EnableMaterializedView.GetAsBool() {
colSchema := ib.meta.GetCollection(meta.CollectionID).Schema
if colSchema != nil {
hasPartitionKey := typeutil.HasPartitionKey(colSchema)
if hasPartitionKey {
partitionKeyField, err := typeutil.GetPartitionKeyFieldSchema(colSchema)
if partitionKeyField == nil || err != nil {
log.Ctx(ib.ctx).Warn("index builder get partition key field failed", zap.Int64("build", buildID), zap.Error(err))
} else {
if typeutil.IsFieldDataTypeSupportMaterializedView(partitionKeyField) {
optionalFields = append(optionalFields, &indexpb.OptionalFieldInfo{
FieldID: partitionKeyField.FieldID,
FieldName: partitionKeyField.Name,
FieldType: int32(partitionKeyField.DataType),
DataIds: getBinLogIDs(segment, partitionKeyField.FieldID),
})
}
}
}
}
}
typeParams := ib.meta.indexMeta.GetTypeParams(meta.CollectionID, meta.IndexID)
var storageConfig *indexpb.StorageConfig
if Params.CommonCfg.StorageType.GetValue() == "local" {
storageConfig = &indexpb.StorageConfig{
RootPath: Params.LocalStorageCfg.Path.GetValue(),
StorageType: Params.CommonCfg.StorageType.GetValue(),
}
} else {
storageConfig = &indexpb.StorageConfig{
Address: Params.MinioCfg.Address.GetValue(),
AccessKeyID: Params.MinioCfg.AccessKeyID.GetValue(),
SecretAccessKey: Params.MinioCfg.SecretAccessKey.GetValue(),
UseSSL: Params.MinioCfg.UseSSL.GetAsBool(),
SslCACert: Params.MinioCfg.SslCACert.GetValue(),
BucketName: Params.MinioCfg.BucketName.GetValue(),
RootPath: Params.MinioCfg.RootPath.GetValue(),
UseIAM: Params.MinioCfg.UseIAM.GetAsBool(),
IAMEndpoint: Params.MinioCfg.IAMEndpoint.GetValue(),
StorageType: Params.CommonCfg.StorageType.GetValue(),
Region: Params.MinioCfg.Region.GetValue(),
UseVirtualHost: Params.MinioCfg.UseVirtualHost.GetAsBool(),
CloudProvider: Params.MinioCfg.CloudProvider.GetValue(),
RequestTimeoutMs: Params.MinioCfg.RequestTimeoutMs.GetAsInt64(),
}
}
fieldID := ib.meta.indexMeta.GetFieldIDByIndexID(meta.CollectionID, meta.IndexID)
binlogIDs := getBinLogIDs(segment, fieldID)
if isDiskANNIndex(GetIndexType(indexParams)) {
var err error
indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams)
if err != nil {
log.Ctx(ib.ctx).Warn("failed to append index build params", zap.Int64("buildID", buildID),
zap.Int64("nodeID", nodeID), zap.Error(err))
}
}
var req *indexpb.CreateJobRequest
collectionInfo, err := ib.handler.GetCollection(ib.ctx, segment.GetCollectionID())
if err != nil {
log.Ctx(ib.ctx).Info("index builder get collection info failed", zap.Int64("collectionID", segment.GetCollectionID()), zap.Error(err))
return false
}
schema := collectionInfo.Schema
var field *schemapb.FieldSchema
for _, f := range schema.Fields {
if f.FieldID == fieldID {
field = f
break
}
}
dim, err := storage.GetDimFromParams(field.TypeParams)
if err != nil {
log.Ctx(ib.ctx).Warn("failed to get dim from field type params",
zap.String("field type", field.GetDataType().String()), zap.Error(err))
// don't return; the field may be a scalar field or a sparse float vector
}
if Params.CommonCfg.EnableStorageV2.GetAsBool() {
storePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue(), segment.GetID())
if err != nil {
log.Ctx(ib.ctx).Warn("failed to get storage uri", zap.Error(err))
return false
}
indexStorePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue()+"/index", segment.GetID())
if err != nil {
log.Ctx(ib.ctx).Warn("failed to get storage uri", zap.Error(err))
return false
}
req = &indexpb.CreateJobRequest{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
IndexFilePrefix: path.Join(ib.chunkManager.RootPath(), common.SegmentIndexPath),
BuildID: buildID,
IndexVersion: meta.IndexVersion + 1,
StorageConfig: storageConfig,
IndexParams: indexParams,
TypeParams: typeParams,
NumRows: meta.NumRows,
CollectionID: segment.GetCollectionID(),
PartitionID: segment.GetPartitionID(),
SegmentID: segment.GetID(),
FieldID: fieldID,
FieldName: field.Name,
FieldType: field.DataType,
StorePath: storePath,
StoreVersion: segment.GetStorageVersion(),
IndexStorePath: indexStorePath,
Dim: int64(dim),
CurrentIndexVersion: ib.indexEngineVersionManager.GetCurrentIndexEngineVersion(),
DataIds: binlogIDs,
OptionalScalarFields: optionalFields,
Field: field,
}
} else {
req = &indexpb.CreateJobRequest{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
IndexFilePrefix: path.Join(ib.chunkManager.RootPath(), common.SegmentIndexPath),
BuildID: buildID,
IndexVersion: meta.IndexVersion + 1,
StorageConfig: storageConfig,
IndexParams: indexParams,
TypeParams: typeParams,
NumRows: meta.NumRows,
CurrentIndexVersion: ib.indexEngineVersionManager.GetCurrentIndexEngineVersion(),
DataIds: binlogIDs,
CollectionID: segment.GetCollectionID(),
PartitionID: segment.GetPartitionID(),
SegmentID: segment.GetID(),
FieldID: fieldID,
OptionalScalarFields: optionalFields,
Dim: int64(dim),
Field: field,
}
}
if err := ib.assignTask(client, req); err != nil {
// need to release lock then reassign, so set task state to retry
log.Ctx(ib.ctx).Warn("index builder assign task to IndexNode failed", zap.Int64("buildID", buildID),
zap.Int64("nodeID", nodeID), zap.Error(err))
updateStateFunc(buildID, indexTaskRetry)
return false
}
log.Ctx(ib.ctx).Info("index task assigned successfully", zap.Int64("buildID", buildID),
zap.Int64("segmentID", meta.SegmentID), zap.Int64("nodeID", nodeID))
// update index meta state to InProgress
if err := ib.meta.indexMeta.BuildIndex(buildID); err != nil {
// need to release lock then reassign, so set task state to retry
log.Ctx(ib.ctx).Warn("index builder update index meta to InProgress failed", zap.Int64("buildID", buildID),
zap.Int64("nodeID", nodeID), zap.Error(err))
updateStateFunc(buildID, indexTaskRetry)
return false
}
updateStateFunc(buildID, indexTaskInProgress)
case indexTaskDone:
if !ib.dropIndexTask(buildID, meta.NodeID) {
return true
}
deleteFunc(buildID)
case indexTaskRetry:
if !ib.dropIndexTask(buildID, meta.NodeID) {
return true
}
updateStateFunc(buildID, indexTaskInit)
default:
// state: in_progress
updateStateFunc(buildID, ib.getTaskState(buildID, meta.NodeID))
}
return true
}
func (ib *indexBuilder) getTaskState(buildID, nodeID UniqueID) indexTaskState {
client, exist := ib.nodeManager.GetClientByID(nodeID)
if exist {
ctx1, cancel := context.WithTimeout(ib.ctx, reqTimeoutInterval)
defer cancel()
response, err := client.QueryJobs(ctx1, &indexpb.QueryJobsRequest{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
BuildIDs: []int64{buildID},
})
if err != nil {
log.Ctx(ib.ctx).Warn("IndexCoord get jobs info from IndexNode fail", zap.Int64("nodeID", nodeID),
zap.Error(err))
return indexTaskRetry
}
if response.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Ctx(ib.ctx).Warn("IndexCoord get jobs info from IndexNode fail", zap.Int64("nodeID", nodeID),
zap.Int64("buildID", buildID), zap.String("fail reason", response.GetStatus().GetReason()))
return indexTaskRetry
}
// indexInfos length is always one.
for _, info := range response.GetIndexInfos() {
if info.GetBuildID() == buildID {
if info.GetState() == commonpb.IndexState_Failed || info.GetState() == commonpb.IndexState_Finished {
log.Ctx(ib.ctx).Info("this task has been finished", zap.Int64("buildID", info.GetBuildID()),
zap.String("index state", info.GetState().String()))
if err := ib.meta.indexMeta.FinishTask(info); err != nil {
log.Ctx(ib.ctx).Warn("IndexCoord update index state fail", zap.Int64("buildID", info.GetBuildID()),
zap.String("index state", info.GetState().String()), zap.Error(err))
return indexTaskInProgress
}
return indexTaskDone
} else if info.GetState() == commonpb.IndexState_Retry || info.GetState() == commonpb.IndexState_IndexStateNone {
log.Ctx(ib.ctx).Info("this task should be retry", zap.Int64("buildID", buildID), zap.String("fail reason", info.GetFailReason()))
return indexTaskRetry
}
return indexTaskInProgress
}
}
log.Ctx(ib.ctx).Info("this task should be retry, indexNode does not have this task", zap.Int64("buildID", buildID),
zap.Int64("nodeID", nodeID))
return indexTaskRetry
}
// !exist --> node down
log.Ctx(ib.ctx).Info("this task should be retry, indexNode is no longer exist", zap.Int64("buildID", buildID),
zap.Int64("nodeID", nodeID))
return indexTaskRetry
}
func (ib *indexBuilder) dropIndexTask(buildID, nodeID UniqueID) bool {
client, exist := ib.nodeManager.GetClientByID(nodeID)
if exist {
ctx1, cancel := context.WithTimeout(ib.ctx, reqTimeoutInterval)
defer cancel()
status, err := client.DropJobs(ctx1, &indexpb.DropJobsRequest{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
BuildIDs: []UniqueID{buildID},
})
if err != nil {
log.Ctx(ib.ctx).Warn("IndexCoord notify IndexNode drop the index task fail", zap.Int64("buildID", buildID),
zap.Int64("nodeID", nodeID), zap.Error(err))
return false
}
if status.GetErrorCode() != commonpb.ErrorCode_Success {
log.Ctx(ib.ctx).Warn("IndexCoord notify IndexNode drop the index task fail", zap.Int64("buildID", buildID),
zap.Int64("nodeID", nodeID), zap.String("fail reason", status.GetReason()))
return false
}
log.Ctx(ib.ctx).Info("IndexCoord notify IndexNode drop the index task success",
zap.Int64("buildID", buildID), zap.Int64("nodeID", nodeID))
return true
}
log.Ctx(ib.ctx).Info("IndexNode no longer exist, no need to drop index task",
zap.Int64("buildID", buildID), zap.Int64("nodeID", nodeID))
return true
}
// assignTask sends the index task to the IndexNode. It has a timeout interval; if the IndexNode doesn't respond within
// the interval, sending the task is considered to have failed.
func (ib *indexBuilder) assignTask(builderClient types.IndexNodeClient, req *indexpb.CreateJobRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), reqTimeoutInterval)
defer cancel()
resp, err := builderClient.CreateJob(ctx, req)
if err == nil {
err = merr.Error(resp)
}
if err != nil {
log.Error("IndexCoord assignmentTasksLoop builderClient.CreateIndex failed", zap.Error(err))
return err
}
return nil
}
func (ib *indexBuilder) nodeDown(nodeID UniqueID) {
defer ib.notify()
metas := ib.meta.indexMeta.GetMetasByNodeID(nodeID)
ib.taskMutex.Lock()
defer ib.taskMutex.Unlock()
for _, meta := range metas {
if ib.tasks[meta.BuildID] != indexTaskDone {
ib.tasks[meta.BuildID] = indexTaskRetry
}
}
}
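
The removed indexBuilder above relies on a notify-or-tick scheduling loop: producers do a non-blocking send on a one-slot channel, and a periodic ticker guarantees progress even if a notification is dropped. A condensed, standalone sketch of that pattern (which the new taskScheduler presumably keeps) is shown below for reference; all names here are illustrative only.

package sketch

import (
	"context"
	"time"
)

type loop struct {
	notifyChan chan struct{}
}

func newLoop() *loop {
	// capacity 1 so a pending wake-up is remembered without blocking senders
	return &loop{notifyChan: make(chan struct{}, 1)}
}

func (l *loop) notify() {
	select {
	case l.notifyChan <- struct{}{}:
	default: // a wake-up is already pending; dropping this one is safe
	}
}

func (l *loop) run(ctx context.Context, interval time.Duration, work func()) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-l.notifyChan:
			work()
		case <-ticker.C:
			work()
		}
	}
}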

File diff suppressed because it is too large

View File

@ -653,18 +653,17 @@ func (m *indexMeta) IsIndexExist(collID, indexID UniqueID) bool {
}
// UpdateVersion updates the version and nodeID of the index meta; each time the task is built, the version is updated once.
func (m *indexMeta) UpdateVersion(buildID UniqueID, nodeID UniqueID) error {
func (m *indexMeta) UpdateVersion(buildID UniqueID) error {
m.Lock()
defer m.Unlock()
log.Debug("IndexCoord metaTable UpdateVersion receive", zap.Int64("buildID", buildID), zap.Int64("nodeID", nodeID))
log.Debug("IndexCoord metaTable UpdateVersion receive", zap.Int64("buildID", buildID))
segIdx, ok := m.buildID2SegmentIndex[buildID]
if !ok {
return fmt.Errorf("there is no index with buildID: %d", buildID)
}
updateFunc := func(segIdx *model.SegmentIndex) error {
segIdx.NodeID = nodeID
segIdx.IndexVersion++
return m.alterSegmentIndexes([]*model.SegmentIndex{segIdx})
}
@ -728,7 +727,7 @@ func (m *indexMeta) DeleteTask(buildID int64) error {
}
// BuildIndex sets the index state to InProgress, which means an IndexNode is building the index.
func (m *indexMeta) BuildIndex(buildID UniqueID) error {
func (m *indexMeta) BuildIndex(buildID, nodeID UniqueID) error {
m.Lock()
defer m.Unlock()
@ -738,6 +737,7 @@ func (m *indexMeta) BuildIndex(buildID UniqueID) error {
}
updateFunc := func(segIdx *model.SegmentIndex) error {
segIdx.NodeID = nodeID
segIdx.IndexState = commonpb.IndexState_InProgress
err := m.alterSegmentIndexes([]*model.SegmentIndex{segIdx})
@ -828,7 +828,7 @@ func (m *indexMeta) RemoveIndex(collID, indexID UniqueID) error {
return nil
}
func (m *indexMeta) CleanSegmentIndex(buildID UniqueID) (bool, *model.SegmentIndex) {
func (m *indexMeta) CheckCleanSegmentIndex(buildID UniqueID) (bool, *model.SegmentIndex) {
m.RLock()
defer m.RUnlock()

View File

@ -693,7 +693,8 @@ func TestMeta_MarkIndexAsDeleted(t *testing.T) {
}
func TestMeta_GetSegmentIndexes(t *testing.T) {
m := createMetaTable(&datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)})
catalog := &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}
m := createMeta(catalog, nil, createIndexMeta(catalog))
t.Run("success", func(t *testing.T) {
segIndexes := m.indexMeta.getSegmentIndexes(segID)
@ -1075,18 +1076,18 @@ func TestMeta_UpdateVersion(t *testing.T) {
).Return(errors.New("fail"))
t.Run("success", func(t *testing.T) {
err := m.UpdateVersion(buildID, nodeID)
err := m.UpdateVersion(buildID)
assert.NoError(t, err)
})
t.Run("fail", func(t *testing.T) {
m.catalog = ec
err := m.UpdateVersion(buildID, nodeID)
err := m.UpdateVersion(buildID)
assert.Error(t, err)
})
t.Run("not exist", func(t *testing.T) {
err := m.UpdateVersion(buildID+1, nodeID)
err := m.UpdateVersion(buildID + 1)
assert.Error(t, err)
})
}
@ -1143,18 +1144,18 @@ func TestMeta_BuildIndex(t *testing.T) {
).Return(errors.New("fail"))
t.Run("success", func(t *testing.T) {
err := m.BuildIndex(buildID)
err := m.BuildIndex(buildID, nodeID)
assert.NoError(t, err)
})
t.Run("fail", func(t *testing.T) {
m.catalog = ec
err := m.BuildIndex(buildID)
err := m.BuildIndex(buildID, nodeID)
assert.Error(t, err)
})
t.Run("not exist", func(t *testing.T) {
err := m.BuildIndex(buildID + 1)
err := m.BuildIndex(buildID+1, nodeID)
assert.Error(t, err)
})
}
@ -1330,7 +1331,8 @@ func TestRemoveSegmentIndex(t *testing.T) {
}
func TestIndexMeta_GetUnindexedSegments(t *testing.T) {
m := createMetaTable(&datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)})
catalog := &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}
m := createMeta(catalog, nil, createIndexMeta(catalog))
// normal case
segmentIDs := make([]int64, 0, 11)

View File

@ -48,8 +48,6 @@ func (s *Server) serverID() int64 {
}
func (s *Server) startIndexService(ctx context.Context) {
s.indexBuilder.Start()
s.serverLoopWg.Add(1)
go s.createIndexForSegmentLoop(ctx)
}
@ -73,7 +71,13 @@ func (s *Server) createIndexForSegment(segment *SegmentInfo, indexID UniqueID) e
if err = s.meta.indexMeta.AddSegmentIndex(segIndex); err != nil {
return err
}
s.indexBuilder.enqueue(buildID)
s.taskScheduler.enqueue(&indexBuildTask{
buildID: buildID,
taskInfo: &indexpb.IndexTaskInfo{
BuildID: buildID,
State: commonpb.IndexState_Unissued,
},
})
return nil
}

View File

@ -24,7 +24,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
@ -33,6 +32,16 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
)
type WorkerManager interface {
AddNode(nodeID UniqueID, address string) error
RemoveNode(nodeID UniqueID)
StoppingNode(nodeID UniqueID)
PickClient() (UniqueID, types.IndexNodeClient)
ClientSupportDisk() bool
GetAllClients() map[UniqueID]types.IndexNodeClient
GetClientByID(nodeID UniqueID) (types.IndexNodeClient, bool)
}
// IndexNodeManager is used to manage the client of IndexNode.
type IndexNodeManager struct {
nodeClients map[UniqueID]types.IndexNodeClient
@ -98,59 +107,55 @@ func (nm *IndexNodeManager) AddNode(nodeID UniqueID, address string) error {
return nil
}
// PeekClient peeks the client with the least load.
func (nm *IndexNodeManager) PeekClient(meta *model.SegmentIndex) (UniqueID, types.IndexNodeClient) {
allClients := nm.GetAllClients()
if len(allClients) == 0 {
log.Error("there is no IndexNode online")
return -1, nil
}
func (nm *IndexNodeManager) PickClient() (UniqueID, types.IndexNodeClient) {
nm.lock.Lock()
defer nm.lock.Unlock()
// Note: use a cancellable context so that the remaining goroutines can be ended quickly once a client is selected
ctx, cancel := context.WithCancel(nm.ctx)
var (
peekNodeID = UniqueID(0)
nodeMutex = lock.Mutex{}
pickNodeID = UniqueID(0)
nodeMutex = sync.Mutex{}
wg = sync.WaitGroup{}
)
for nodeID, client := range allClients {
nodeID := nodeID
client := client
wg.Add(1)
go func() {
defer wg.Done()
resp, err := client.GetJobStats(ctx, &indexpb.GetJobStatsRequest{})
if err != nil {
log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), zap.Error(err))
return
}
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID),
zap.String("reason", resp.GetStatus().GetReason()))
return
}
if resp.GetTaskSlots() > 0 {
nodeMutex.Lock()
defer nodeMutex.Unlock()
log.Info("peek client success", zap.Int64("nodeID", nodeID))
if peekNodeID == 0 {
peekNodeID = nodeID
for nodeID, client := range nm.nodeClients {
if _, ok := nm.stoppingNodes[nodeID]; !ok {
nodeID := nodeID
client := client
wg.Add(1)
go func() {
defer wg.Done()
resp, err := client.GetJobStats(ctx, &indexpb.GetJobStatsRequest{})
if err != nil {
log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), zap.Error(err))
return
}
cancel()
// Note: In order to quickly end other goroutines, an error is returned when the client is successfully selected
return
}
}()
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID),
zap.String("reason", resp.GetStatus().GetReason()))
return
}
if resp.GetTaskSlots() > 0 {
nodeMutex.Lock()
defer nodeMutex.Unlock()
if pickNodeID == 0 {
pickNodeID = nodeID
}
cancel()
// Note: cancel the context to quickly end the other goroutines once a client has been selected
return
}
}()
}
}
wg.Wait()
cancel()
if peekNodeID != 0 {
log.Info("peek client success", zap.Int64("nodeID", peekNodeID))
return peekNodeID, allClients[peekNodeID]
if pickNodeID != 0 {
log.Info("pick indexNode success", zap.Int64("nodeID", pickNodeID))
return pickNodeID, nm.nodeClients[pickNodeID]
}
log.RatedDebug(5, "peek client fail")
return 0, nil
}
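
PickClient fans out GetJobStats calls concurrently and keeps the first node that reports free task slots, cancelling the rest. A hypothetical caller-side sketch is shown below; the names tryDispatch, task and mt are placeholders, since the real dispatch logic lives in the new taskScheduler.

// tryDispatch is a hypothetical helper, not part of this change.
func tryDispatch(nm WorkerManager, task *analyzeTask, mt *meta) bool {
	nodeID, client := nm.PickClient()
	if client == nil {
		// every online IndexNode is busy; keep the task queued and let the
		// scheduler's ticker retry on the next round
		return false
	}
	// record the assignment before sending the request to the picked client
	if err := task.UpdateMetaBuildingState(nodeID, mt); err != nil {
		return false
	}
	return true
}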

View File

@ -24,7 +24,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/types"
@ -34,9 +33,6 @@ import (
func TestIndexNodeManager_AddNode(t *testing.T) {
nm := NewNodeManager(context.Background(), defaultIndexNodeCreatorFunc)
nodeID, client := nm.PeekClient(&model.SegmentIndex{})
assert.Equal(t, int64(-1), nodeID)
assert.Nil(t, client)
t.Run("success", func(t *testing.T) {
err := nm.AddNode(1, "indexnode-1")
@ -49,7 +45,7 @@ func TestIndexNodeManager_AddNode(t *testing.T) {
})
}
func TestIndexNodeManager_PeekClient(t *testing.T) {
func TestIndexNodeManager_PickClient(t *testing.T) {
getMockedGetJobStatsClient := func(resp *indexpb.GetJobStatsResponse, err error) types.IndexNodeClient {
ic := mocks.NewMockIndexNodeClient(t)
ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(resp, err)
@ -94,9 +90,9 @@ func TestIndexNodeManager_PeekClient(t *testing.T) {
},
}
nodeID, client := nm.PeekClient(&model.SegmentIndex{})
selectNodeID, client := nm.PickClient()
assert.NotNil(t, client)
assert.Contains(t, []UniqueID{8, 9}, nodeID)
assert.Contains(t, []UniqueID{8, 9}, selectNodeID)
})
}

View File

@ -57,10 +57,15 @@ type CompactionMeta interface {
SetSegmentCompacting(segmentID int64, compacting bool)
CheckAndSetSegmentsCompacting(segmentIDs []int64) (bool, bool)
CompleteCompactionMutation(plan *datapb.CompactionPlan, result *datapb.CompactionPlanResult) ([]*SegmentInfo, *segMetricMutation, error)
SaveCompactionTask(task *datapb.CompactionTask) error
DropCompactionTask(task *datapb.CompactionTask) error
GetCompactionTasks() map[int64][]*datapb.CompactionTask
GetCompactionTasksByTriggerID(triggerID int64) []*datapb.CompactionTask
GetIndexMeta() *indexMeta
GetAnalyzeMeta() *analyzeMeta
GetCompactionTaskMeta() *compactionTaskMeta
}
var _ CompactionMeta = (*meta)(nil)
@ -75,9 +80,22 @@ type meta struct {
chunkManager storage.ChunkManager
indexMeta *indexMeta
analyzeMeta *analyzeMeta
compactionTaskMeta *compactionTaskMeta
}
func (m *meta) GetIndexMeta() *indexMeta {
return m.indexMeta
}
func (m *meta) GetAnalyzeMeta() *analyzeMeta {
return m.analyzeMeta
}
func (m *meta) GetCompactionTaskMeta() *compactionTaskMeta {
return m.compactionTaskMeta
}
type channelCPs struct {
lock.RWMutex
checkpoints map[string]*msgpb.MsgPosition
@ -110,7 +128,12 @@ type collectionInfo struct {
// NewMeta creates meta from provided `kv.TxnKV`
func newMeta(ctx context.Context, catalog metastore.DataCoordCatalog, chunkManager storage.ChunkManager) (*meta, error) {
indexMeta, err := newIndexMeta(ctx, catalog)
im, err := newIndexMeta(ctx, catalog)
if err != nil {
return nil, err
}
am, err := newAnalyzeMeta(ctx, catalog)
if err != nil {
return nil, err
}
@ -119,14 +142,14 @@ func newMeta(ctx context.Context, catalog metastore.DataCoordCatalog, chunkManag
if err != nil {
return nil, err
}
mt := &meta{
ctx: ctx,
catalog: catalog,
collections: make(map[UniqueID]*collectionInfo),
segments: NewSegmentsInfo(),
channelCPs: newChannelCps(),
indexMeta: indexMeta,
indexMeta: im,
analyzeMeta: am,
chunkManager: chunkManager,
compactionTaskMeta: ctm,
}
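
With the getters added above, consumers that only see the CompactionMeta interface can still reach analyze results. A hypothetical example follows: clustering compaction reading the centroids file produced by a finished analyze task. The helper name and wiring are assumptions for illustration, not code from this change.

// centroidsFileFor is a hypothetical helper, not part of this change.
func centroidsFileFor(cm CompactionMeta, analyzeTaskID int64) string {
	t := cm.GetAnalyzeMeta().GetTask(analyzeTaskID)
	if t == nil || t.GetState() != indexpb.JobState_JobStateFinished {
		return "" // not analyzed yet, or the task failed
	}
	return t.GetCentroidsFile()
}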

View File

@ -71,6 +71,7 @@ func (suite *MetaReloadSuite) TestReloadFromKV() {
suite.catalog.EXPECT().ListSegments(mock.Anything).Return(nil, errors.New("mock"))
suite.catalog.EXPECT().ListIndexes(mock.Anything).Return([]*model.Index{}, nil)
suite.catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return([]*model.SegmentIndex{}, nil)
suite.catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil)
suite.catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil)
_, err := newMeta(ctx, suite.catalog, nil)
@ -84,6 +85,7 @@ func (suite *MetaReloadSuite) TestReloadFromKV() {
suite.catalog.EXPECT().ListChannelCheckpoint(mock.Anything).Return(nil, errors.New("mock"))
suite.catalog.EXPECT().ListIndexes(mock.Anything).Return([]*model.Index{}, nil)
suite.catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return([]*model.SegmentIndex{}, nil)
suite.catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil)
suite.catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil)
_, err := newMeta(ctx, suite.catalog, nil)
@ -94,6 +96,7 @@ func (suite *MetaReloadSuite) TestReloadFromKV() {
defer suite.resetMock()
suite.catalog.EXPECT().ListIndexes(mock.Anything).Return([]*model.Index{}, nil)
suite.catalog.EXPECT().ListSegmentIndexes(mock.Anything).Return([]*model.SegmentIndex{}, nil)
suite.catalog.EXPECT().ListAnalyzeTasks(mock.Anything).Return(nil, nil)
suite.catalog.EXPECT().ListCompactionTask(mock.Anything).Return(nil, nil)
suite.catalog.EXPECT().ListSegments(mock.Anything).Return([]*datapb.SegmentInfo{
{
@ -622,7 +625,7 @@ func TestMeta_Basic(t *testing.T) {
})
t.Run("Test GetCollectionBinlogSize", func(t *testing.T) {
meta := createMetaTable(&datacoord.Catalog{})
meta := createMeta(&datacoord.Catalog{}, nil, createIndexMeta(&datacoord.Catalog{}))
ret := meta.GetCollectionIndexFilesSize()
assert.Equal(t, uint64(0), ret)

View File

@ -178,6 +178,92 @@ func (_c *MockCompactionMeta_DropCompactionTask_Call) RunAndReturn(run func(*dat
return _c
}
// GetAnalyzeMeta provides a mock function with given fields:
func (_m *MockCompactionMeta) GetAnalyzeMeta() *analyzeMeta {
ret := _m.Called()
var r0 *analyzeMeta
if rf, ok := ret.Get(0).(func() *analyzeMeta); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*analyzeMeta)
}
}
return r0
}
// MockCompactionMeta_GetAnalyzeMeta_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAnalyzeMeta'
type MockCompactionMeta_GetAnalyzeMeta_Call struct {
*mock.Call
}
// GetAnalyzeMeta is a helper method to define mock.On call
func (_e *MockCompactionMeta_Expecter) GetAnalyzeMeta() *MockCompactionMeta_GetAnalyzeMeta_Call {
return &MockCompactionMeta_GetAnalyzeMeta_Call{Call: _e.mock.On("GetAnalyzeMeta")}
}
func (_c *MockCompactionMeta_GetAnalyzeMeta_Call) Run(run func()) *MockCompactionMeta_GetAnalyzeMeta_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCompactionMeta_GetAnalyzeMeta_Call) Return(_a0 *analyzeMeta) *MockCompactionMeta_GetAnalyzeMeta_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCompactionMeta_GetAnalyzeMeta_Call) RunAndReturn(run func() *analyzeMeta) *MockCompactionMeta_GetAnalyzeMeta_Call {
_c.Call.Return(run)
return _c
}
// GetCompactionTaskMeta provides a mock function with given fields:
func (_m *MockCompactionMeta) GetCompactionTaskMeta() *compactionTaskMeta {
ret := _m.Called()
var r0 *compactionTaskMeta
if rf, ok := ret.Get(0).(func() *compactionTaskMeta); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*compactionTaskMeta)
}
}
return r0
}
// MockCompactionMeta_GetCompactionTaskMeta_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompactionTaskMeta'
type MockCompactionMeta_GetCompactionTaskMeta_Call struct {
*mock.Call
}
// GetCompactionTaskMeta is a helper method to define mock.On call
func (_e *MockCompactionMeta_Expecter) GetCompactionTaskMeta() *MockCompactionMeta_GetCompactionTaskMeta_Call {
return &MockCompactionMeta_GetCompactionTaskMeta_Call{Call: _e.mock.On("GetCompactionTaskMeta")}
}
func (_c *MockCompactionMeta_GetCompactionTaskMeta_Call) Run(run func()) *MockCompactionMeta_GetCompactionTaskMeta_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCompactionMeta_GetCompactionTaskMeta_Call) Return(_a0 *compactionTaskMeta) *MockCompactionMeta_GetCompactionTaskMeta_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCompactionMeta_GetCompactionTaskMeta_Call) RunAndReturn(run func() *compactionTaskMeta) *MockCompactionMeta_GetCompactionTaskMeta_Call {
_c.Call.Return(run)
return _c
}
// GetCompactionTasks provides a mock function with given fields:
func (_m *MockCompactionMeta) GetCompactionTasks() map[int64][]*datapb.CompactionTask {
ret := _m.Called()
@ -309,6 +395,49 @@ func (_c *MockCompactionMeta_GetHealthySegment_Call) RunAndReturn(run func(int64
return _c
}
// GetIndexMeta provides a mock function with given fields:
func (_m *MockCompactionMeta) GetIndexMeta() *indexMeta {
ret := _m.Called()
var r0 *indexMeta
if rf, ok := ret.Get(0).(func() *indexMeta); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*indexMeta)
}
}
return r0
}
// MockCompactionMeta_GetIndexMeta_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexMeta'
type MockCompactionMeta_GetIndexMeta_Call struct {
*mock.Call
}
// GetIndexMeta is a helper method to define mock.On call
func (_e *MockCompactionMeta_Expecter) GetIndexMeta() *MockCompactionMeta_GetIndexMeta_Call {
return &MockCompactionMeta_GetIndexMeta_Call{Call: _e.mock.On("GetIndexMeta")}
}
func (_c *MockCompactionMeta_GetIndexMeta_Call) Run(run func()) *MockCompactionMeta_GetIndexMeta_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCompactionMeta_GetIndexMeta_Call) Return(_a0 *indexMeta) *MockCompactionMeta_GetIndexMeta_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCompactionMeta_GetIndexMeta_Call) RunAndReturn(run func() *indexMeta) *MockCompactionMeta_GetIndexMeta_Call {
_c.Call.Return(run)
return _c
}
// GetSegment provides a mock function with given fields: segID
func (_m *MockCompactionMeta) GetSegment(segID int64) *SegmentInfo {
ret := _m.Called(segID)

View File

@ -0,0 +1,335 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package datacoord
import (
types "github.com/milvus-io/milvus/internal/types"
mock "github.com/stretchr/testify/mock"
)
// MockWorkerManager is an autogenerated mock type for the WorkerManager type
type MockWorkerManager struct {
mock.Mock
}
type MockWorkerManager_Expecter struct {
mock *mock.Mock
}
func (_m *MockWorkerManager) EXPECT() *MockWorkerManager_Expecter {
return &MockWorkerManager_Expecter{mock: &_m.Mock}
}
// AddNode provides a mock function with given fields: nodeID, address
func (_m *MockWorkerManager) AddNode(nodeID int64, address string) error {
ret := _m.Called(nodeID, address)
var r0 error
if rf, ok := ret.Get(0).(func(int64, string) error); ok {
r0 = rf(nodeID, address)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockWorkerManager_AddNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddNode'
type MockWorkerManager_AddNode_Call struct {
*mock.Call
}
// AddNode is a helper method to define mock.On call
// - nodeID int64
// - address string
func (_e *MockWorkerManager_Expecter) AddNode(nodeID interface{}, address interface{}) *MockWorkerManager_AddNode_Call {
return &MockWorkerManager_AddNode_Call{Call: _e.mock.On("AddNode", nodeID, address)}
}
func (_c *MockWorkerManager_AddNode_Call) Run(run func(nodeID int64, address string)) *MockWorkerManager_AddNode_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(string))
})
return _c
}
func (_c *MockWorkerManager_AddNode_Call) Return(_a0 error) *MockWorkerManager_AddNode_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockWorkerManager_AddNode_Call) RunAndReturn(run func(int64, string) error) *MockWorkerManager_AddNode_Call {
_c.Call.Return(run)
return _c
}
// ClientSupportDisk provides a mock function with given fields:
func (_m *MockWorkerManager) ClientSupportDisk() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockWorkerManager_ClientSupportDisk_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ClientSupportDisk'
type MockWorkerManager_ClientSupportDisk_Call struct {
*mock.Call
}
// ClientSupportDisk is a helper method to define mock.On call
func (_e *MockWorkerManager_Expecter) ClientSupportDisk() *MockWorkerManager_ClientSupportDisk_Call {
return &MockWorkerManager_ClientSupportDisk_Call{Call: _e.mock.On("ClientSupportDisk")}
}
func (_c *MockWorkerManager_ClientSupportDisk_Call) Run(run func()) *MockWorkerManager_ClientSupportDisk_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockWorkerManager_ClientSupportDisk_Call) Return(_a0 bool) *MockWorkerManager_ClientSupportDisk_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockWorkerManager_ClientSupportDisk_Call) RunAndReturn(run func() bool) *MockWorkerManager_ClientSupportDisk_Call {
_c.Call.Return(run)
return _c
}
// GetAllClients provides a mock function with given fields:
func (_m *MockWorkerManager) GetAllClients() map[int64]types.IndexNodeClient {
ret := _m.Called()
var r0 map[int64]types.IndexNodeClient
if rf, ok := ret.Get(0).(func() map[int64]types.IndexNodeClient); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]types.IndexNodeClient)
}
}
return r0
}
// MockWorkerManager_GetAllClients_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAllClients'
type MockWorkerManager_GetAllClients_Call struct {
*mock.Call
}
// GetAllClients is a helper method to define mock.On call
func (_e *MockWorkerManager_Expecter) GetAllClients() *MockWorkerManager_GetAllClients_Call {
return &MockWorkerManager_GetAllClients_Call{Call: _e.mock.On("GetAllClients")}
}
func (_c *MockWorkerManager_GetAllClients_Call) Run(run func()) *MockWorkerManager_GetAllClients_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockWorkerManager_GetAllClients_Call) Return(_a0 map[int64]types.IndexNodeClient) *MockWorkerManager_GetAllClients_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockWorkerManager_GetAllClients_Call) RunAndReturn(run func() map[int64]types.IndexNodeClient) *MockWorkerManager_GetAllClients_Call {
_c.Call.Return(run)
return _c
}
// GetClientByID provides a mock function with given fields: nodeID
func (_m *MockWorkerManager) GetClientByID(nodeID int64) (types.IndexNodeClient, bool) {
ret := _m.Called(nodeID)
var r0 types.IndexNodeClient
var r1 bool
if rf, ok := ret.Get(0).(func(int64) (types.IndexNodeClient, bool)); ok {
return rf(nodeID)
}
if rf, ok := ret.Get(0).(func(int64) types.IndexNodeClient); ok {
r0 = rf(nodeID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(types.IndexNodeClient)
}
}
if rf, ok := ret.Get(1).(func(int64) bool); ok {
r1 = rf(nodeID)
} else {
r1 = ret.Get(1).(bool)
}
return r0, r1
}
// MockWorkerManager_GetClientByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetClientByID'
type MockWorkerManager_GetClientByID_Call struct {
*mock.Call
}
// GetClientByID is a helper method to define mock.On call
// - nodeID int64
func (_e *MockWorkerManager_Expecter) GetClientByID(nodeID interface{}) *MockWorkerManager_GetClientByID_Call {
return &MockWorkerManager_GetClientByID_Call{Call: _e.mock.On("GetClientByID", nodeID)}
}
func (_c *MockWorkerManager_GetClientByID_Call) Run(run func(nodeID int64)) *MockWorkerManager_GetClientByID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockWorkerManager_GetClientByID_Call) Return(_a0 types.IndexNodeClient, _a1 bool) *MockWorkerManager_GetClientByID_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockWorkerManager_GetClientByID_Call) RunAndReturn(run func(int64) (types.IndexNodeClient, bool)) *MockWorkerManager_GetClientByID_Call {
_c.Call.Return(run)
return _c
}
// PickClient provides a mock function with given fields:
func (_m *MockWorkerManager) PickClient() (int64, types.IndexNodeClient) {
ret := _m.Called()
var r0 int64
var r1 types.IndexNodeClient
if rf, ok := ret.Get(0).(func() (int64, types.IndexNodeClient)); ok {
return rf()
}
if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int64)
}
if rf, ok := ret.Get(1).(func() types.IndexNodeClient); ok {
r1 = rf()
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(types.IndexNodeClient)
}
}
return r0, r1
}
// MockWorkerManager_PickClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PickClient'
type MockWorkerManager_PickClient_Call struct {
*mock.Call
}
// PickClient is a helper method to define mock.On call
func (_e *MockWorkerManager_Expecter) PickClient() *MockWorkerManager_PickClient_Call {
return &MockWorkerManager_PickClient_Call{Call: _e.mock.On("PickClient")}
}
func (_c *MockWorkerManager_PickClient_Call) Run(run func()) *MockWorkerManager_PickClient_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockWorkerManager_PickClient_Call) Return(_a0 int64, _a1 types.IndexNodeClient) *MockWorkerManager_PickClient_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockWorkerManager_PickClient_Call) RunAndReturn(run func() (int64, types.IndexNodeClient)) *MockWorkerManager_PickClient_Call {
_c.Call.Return(run)
return _c
}
// RemoveNode provides a mock function with given fields: nodeID
func (_m *MockWorkerManager) RemoveNode(nodeID int64) {
_m.Called(nodeID)
}
// MockWorkerManager_RemoveNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveNode'
type MockWorkerManager_RemoveNode_Call struct {
*mock.Call
}
// RemoveNode is a helper method to define mock.On call
// - nodeID int64
func (_e *MockWorkerManager_Expecter) RemoveNode(nodeID interface{}) *MockWorkerManager_RemoveNode_Call {
return &MockWorkerManager_RemoveNode_Call{Call: _e.mock.On("RemoveNode", nodeID)}
}
func (_c *MockWorkerManager_RemoveNode_Call) Run(run func(nodeID int64)) *MockWorkerManager_RemoveNode_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockWorkerManager_RemoveNode_Call) Return() *MockWorkerManager_RemoveNode_Call {
_c.Call.Return()
return _c
}
func (_c *MockWorkerManager_RemoveNode_Call) RunAndReturn(run func(int64)) *MockWorkerManager_RemoveNode_Call {
_c.Call.Return(run)
return _c
}
// StoppingNode provides a mock function with given fields: nodeID
func (_m *MockWorkerManager) StoppingNode(nodeID int64) {
_m.Called(nodeID)
}
// MockWorkerManager_StoppingNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StoppingNode'
type MockWorkerManager_StoppingNode_Call struct {
*mock.Call
}
// StoppingNode is a helper method to define mock.On call
// - nodeID int64
func (_e *MockWorkerManager_Expecter) StoppingNode(nodeID interface{}) *MockWorkerManager_StoppingNode_Call {
return &MockWorkerManager_StoppingNode_Call{Call: _e.mock.On("StoppingNode", nodeID)}
}
func (_c *MockWorkerManager_StoppingNode_Call) Run(run func(nodeID int64)) *MockWorkerManager_StoppingNode_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockWorkerManager_StoppingNode_Call) Return() *MockWorkerManager_StoppingNode_Call {
_c.Call.Return()
return _c
}
func (_c *MockWorkerManager_StoppingNode_Call) RunAndReturn(run func(int64)) *MockWorkerManager_StoppingNode_Call {
_c.Call.Return(run)
return _c
}
// NewMockWorkerManager creates a new instance of MockWorkerManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockWorkerManager(t interface {
mock.TestingT
Cleanup(func())
}) *MockWorkerManager {
mock := &MockWorkerManager{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
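
A minimal illustration of how the generated MockWorkerManager might be used in a datacoord test follows; the expectation below is hypothetical and assumes the usual testify/mockery imports (testing, assert, internal/mocks) rather than quoting an existing test.

// TestPickClientSketch is a hypothetical test, not part of this change.
func TestPickClientSketch(t *testing.T) {
	wm := NewMockWorkerManager(t)
	wm.EXPECT().PickClient().Return(int64(1), mocks.NewMockIndexNodeClient(t))

	nodeID, client := wm.PickClient()
	assert.Equal(t, int64(1), nodeID)
	assert.NotNil(t, client)
}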

View File

@ -152,10 +152,11 @@ type Server struct {
// indexCoord types.IndexCoord
// segReferManager *SegmentReferenceManager
indexBuilder *indexBuilder
indexNodeManager *IndexNodeManager
indexEngineVersionManager IndexEngineVersionManager
taskScheduler *taskScheduler
// manage ways that data coord access other coord
broker broker.Broker
}
@ -373,6 +374,7 @@ func (s *Server) initDataCoord() error {
}
log.Info("init service discovery done")
s.initTaskScheduler(storageCli)
if Params.DataCoordCfg.EnableCompaction.GetAsBool() {
s.createCompactionHandler()
s.createCompactionTrigger()
@ -385,7 +387,6 @@ func (s *Server) initDataCoord() error {
log.Info("init segment manager done")
s.initGarbageCollection(storageCli)
s.initIndexBuilder(storageCli)
s.importMeta, err = NewImportMeta(s.meta.catalog)
if err != nil {
@ -419,6 +420,7 @@ func (s *Server) Start() error {
}
func (s *Server) startDataCoord() {
s.taskScheduler.Start()
if Params.DataCoordCfg.EnableCompaction.GetAsBool() {
s.compactionHandler.start()
s.compactionTrigger.start()
@ -689,9 +691,9 @@ func (s *Server) initMeta(chunkManager storage.ChunkManager) error {
return retry.Do(s.ctx, reloadEtcdFn, retry.Attempts(connMetaMaxRetryTime))
}
func (s *Server) initIndexBuilder(manager storage.ChunkManager) {
if s.indexBuilder == nil {
s.indexBuilder = newIndexBuilder(s.ctx, s.meta, s.indexNodeManager, manager, s.indexEngineVersionManager, s.handler)
func (s *Server) initTaskScheduler(manager storage.ChunkManager) {
if s.taskScheduler == nil {
s.taskScheduler = newTaskScheduler(s.ctx, s.meta, s.indexNodeManager, manager, s.indexEngineVersionManager, s.handler)
}
}
@ -1115,7 +1117,7 @@ func (s *Server) Stop() error {
}
logutil.Logger(s.ctx).Info("datacoord compaction stopped")
s.indexBuilder.Stop()
s.taskScheduler.Stop()
logutil.Logger(s.ctx).Info("datacoord index builder stopped")
s.cluster.Close()

View File

@ -0,0 +1,286 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"fmt"
"math"
"github.com/samber/lo"
"go.uber.org/zap"
"golang.org/x/exp/slices"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type analyzeTask struct {
taskID int64
nodeID int64
taskInfo *indexpb.AnalyzeResult
}
func (at *analyzeTask) GetTaskID() int64 {
return at.taskID
}
func (at *analyzeTask) GetNodeID() int64 {
return at.nodeID
}
func (at *analyzeTask) ResetNodeID() {
at.nodeID = 0
}
func (at *analyzeTask) CheckTaskHealthy(mt *meta) bool {
t := mt.analyzeMeta.GetTask(at.GetTaskID())
return t != nil
}
func (at *analyzeTask) SetState(state indexpb.JobState, failReason string) {
at.taskInfo.State = state
at.taskInfo.FailReason = failReason
}
func (at *analyzeTask) GetState() indexpb.JobState {
return at.taskInfo.GetState()
}
func (at *analyzeTask) GetFailReason() string {
return at.taskInfo.GetFailReason()
}
func (at *analyzeTask) UpdateVersion(ctx context.Context, meta *meta) error {
return meta.analyzeMeta.UpdateVersion(at.GetTaskID())
}
func (at *analyzeTask) UpdateMetaBuildingState(nodeID int64, meta *meta) error {
if err := meta.analyzeMeta.BuildingTask(at.GetTaskID(), nodeID); err != nil {
return err
}
at.nodeID = nodeID
return nil
}
func (at *analyzeTask) AssignTask(ctx context.Context, client types.IndexNodeClient, dependency *taskScheduler) (bool, bool) {
t := dependency.meta.analyzeMeta.GetTask(at.GetTaskID())
if t == nil {
log.Ctx(ctx).Info("task is nil, delete it", zap.Int64("taskID", at.GetTaskID()))
at.SetState(indexpb.JobState_JobStateNone, "analyze task is nil")
return false, false
}
var storageConfig *indexpb.StorageConfig
if Params.CommonCfg.StorageType.GetValue() == "local" {
storageConfig = &indexpb.StorageConfig{
RootPath: Params.LocalStorageCfg.Path.GetValue(),
StorageType: Params.CommonCfg.StorageType.GetValue(),
}
} else {
storageConfig = &indexpb.StorageConfig{
Address: Params.MinioCfg.Address.GetValue(),
AccessKeyID: Params.MinioCfg.AccessKeyID.GetValue(),
SecretAccessKey: Params.MinioCfg.SecretAccessKey.GetValue(),
UseSSL: Params.MinioCfg.UseSSL.GetAsBool(),
BucketName: Params.MinioCfg.BucketName.GetValue(),
RootPath: Params.MinioCfg.RootPath.GetValue(),
UseIAM: Params.MinioCfg.UseIAM.GetAsBool(),
IAMEndpoint: Params.MinioCfg.IAMEndpoint.GetValue(),
StorageType: Params.CommonCfg.StorageType.GetValue(),
Region: Params.MinioCfg.Region.GetValue(),
UseVirtualHost: Params.MinioCfg.UseVirtualHost.GetAsBool(),
CloudProvider: Params.MinioCfg.CloudProvider.GetValue(),
RequestTimeoutMs: Params.MinioCfg.RequestTimeoutMs.GetAsInt64(),
}
}
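// The storage config is embedded in the request so the IndexNode worker can build
// its own chunk manager and read the segment binlogs (a local path when local
// storage is used, MinIO/S3 settings otherwise).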
req := &indexpb.AnalyzeRequest{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
TaskID: at.GetTaskID(),
CollectionID: t.CollectionID,
PartitionID: t.PartitionID,
FieldID: t.FieldID,
FieldName: t.FieldName,
FieldType: t.FieldType,
Dim: t.Dim,
SegmentStats: make(map[int64]*indexpb.SegmentStats),
Version: t.Version,
StorageConfig: storageConfig,
}
// While the analyze task is running, the selected segments must not be dropped (e.g. by compaction or GC).
segments := dependency.meta.SelectSegments(SegmentFilterFunc(func(info *SegmentInfo) bool {
return isSegmentHealthy(info) && slices.Contains(t.SegmentIDs, info.ID)
}))
segmentsMap := lo.SliceToMap(segments, func(t *SegmentInfo) (int64, *SegmentInfo) {
return t.ID, t
})
totalSegmentsRows := int64(0)
for _, segID := range t.SegmentIDs {
info := segmentsMap[segID]
if info == nil {
log.Ctx(ctx).Warn("analyze stats task is processing, but segment is nil, delete the task",
zap.Int64("taskID", at.GetTaskID()), zap.Int64("segmentID", segID))
at.SetState(indexpb.JobState_JobStateFailed, fmt.Sprintf("segmentInfo with ID: %d is nil", segID))
return false, false
}
totalSegmentsRows += info.GetNumOfRows()
// get binlogIDs
binlogIDs := getBinLogIDs(info, t.FieldID)
req.SegmentStats[segID] = &indexpb.SegmentStats{
ID: segID,
NumRows: info.GetNumOfRows(),
LogIDs: binlogIDs,
}
}
collInfo, err := dependency.handler.GetCollection(ctx, segments[0].GetCollectionID())
if err != nil {
log.Ctx(ctx).Info("analyze task get collection info failed", zap.Int64("collectionID",
segments[0].GetCollectionID()), zap.Error(err))
at.SetState(indexpb.JobState_JobStateInit, err.Error())
return false, false
}
schema := collInfo.Schema
var field *schemapb.FieldSchema
for _, f := range schema.Fields {
if f.FieldID == t.FieldID {
field = f
break
}
}
dim, err := storage.GetDimFromParams(field.TypeParams)
if err != nil {
at.SetState(indexpb.JobState_JobStateInit, err.Error())
return false, false
}
req.Dim = int64(dim)
totalSegmentsRawDataSize := float64(totalSegmentsRows) * float64(dim) * typeutil.VectorTypeSize(t.FieldType) // Byte
numClusters := int64(math.Ceil(totalSegmentsRawDataSize / float64(Params.DataCoordCfg.ClusteringCompactionPreferSegmentSize.GetAsSize())))
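// numClusters is derived from the raw vector data size (rows * dim * bytes per
// element) divided by the preferred clustering segment size, then clamped to
// [minCentroidsNum, maxCentroidsNum]. With illustrative values only: 2,000,000
// rows of dim-128 float32 vectors is about 1 GB; a 512 MB preferSegmentSize gives
// ceil(1 GB / 512 MB) = 2 clusters, which is below a minCentroidsNum of 16, so the
// analyze task would be finished immediately without running.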
if numClusters < Params.DataCoordCfg.ClusteringCompactionMinCentroidsNum.GetAsInt64() {
log.Ctx(ctx).Info("data size is too small, skip analyze task", zap.Float64("raw data size", totalSegmentsRawDataSize), zap.Int64("num clusters", numClusters), zap.Int64("minimum num clusters required", Params.DataCoordCfg.ClusteringCompactionMinCentroidsNum.GetAsInt64()))
at.SetState(indexpb.JobState_JobStateFinished, "")
return true, true
}
if numClusters > Params.DataCoordCfg.ClusteringCompactionMaxCentroidsNum.GetAsInt64() {
numClusters = Params.DataCoordCfg.ClusteringCompactionMaxCentroidsNum.GetAsInt64()
}
req.NumClusters = numClusters
req.MaxTrainSizeRatio = Params.DataCoordCfg.ClusteringCompactionMaxTrainSizeRatio.GetAsFloat() // control clustering train data size
// config to detect data skewness
req.MinClusterSizeRatio = Params.DataCoordCfg.ClusteringCompactionMinClusterSizeRatio.GetAsFloat()
req.MaxClusterSizeRatio = Params.DataCoordCfg.ClusteringCompactionMaxClusterSizeRatio.GetAsFloat()
req.MaxClusterSize = Params.DataCoordCfg.ClusteringCompactionMaxClusterSize.GetAsSize()
ctx, cancel := context.WithTimeout(context.Background(), reqTimeoutInterval)
defer cancel()
resp, err := client.CreateJobV2(ctx, &indexpb.CreateJobV2Request{
ClusterID: req.GetClusterID(),
TaskID: req.GetTaskID(),
JobType: indexpb.JobType_JobTypeAnalyzeJob,
Request: &indexpb.CreateJobV2Request_AnalyzeRequest{
AnalyzeRequest: req,
},
})
if err == nil {
err = merr.Error(resp)
}
if err != nil {
log.Ctx(ctx).Warn("assign analyze task to indexNode failed", zap.Int64("taskID", at.GetTaskID()), zap.Error(err))
at.SetState(indexpb.JobState_JobStateRetry, err.Error())
return false, true
}
log.Ctx(ctx).Info("analyze task assigned successfully", zap.Int64("taskID", at.GetTaskID()))
at.SetState(indexpb.JobState_JobStateInProgress, "")
return true, false
}
func (at *analyzeTask) setResult(result *indexpb.AnalyzeResult) {
at.taskInfo = result
}
func (at *analyzeTask) QueryResult(ctx context.Context, client types.IndexNodeClient) {
resp, err := client.QueryJobsV2(ctx, &indexpb.QueryJobsV2Request{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
TaskIDs: []int64{at.GetTaskID()},
JobType: indexpb.JobType_JobTypeAnalyzeJob,
})
if err == nil {
err = merr.Error(resp.GetStatus())
}
if err != nil {
log.Ctx(ctx).Warn("query analysis task result from IndexNode fail", zap.Int64("nodeID", at.GetNodeID()),
zap.Error(err))
at.SetState(indexpb.JobState_JobStateRetry, err.Error())
return
}
// infos length is always one.
for _, result := range resp.GetAnalyzeJobResults().GetResults() {
if result.GetTaskID() == at.GetTaskID() {
log.Ctx(ctx).Info("query analysis task info successfully",
zap.Int64("taskID", at.GetTaskID()), zap.String("result state", result.GetState().String()),
zap.String("failReason", result.GetFailReason()))
if result.GetState() == indexpb.JobState_JobStateFinished || result.GetState() == indexpb.JobState_JobStateFailed ||
result.GetState() == indexpb.JobState_JobStateRetry {
// state is retry or finished or failed
at.setResult(result)
} else if result.GetState() == indexpb.JobState_JobStateNone {
at.SetState(indexpb.JobState_JobStateRetry, "analyze task state is none in info response")
}
// inProgress or unissued/init, keep InProgress state
return
}
}
log.Ctx(ctx).Warn("query analyze task info failed, indexNode does not have task info",
zap.Int64("taskID", at.GetTaskID()))
at.SetState(indexpb.JobState_JobStateRetry, "analyze result is not in info response")
}
func (at *analyzeTask) DropTaskOnWorker(ctx context.Context, client types.IndexNodeClient) bool {
resp, err := client.DropJobsV2(ctx, &indexpb.DropJobsV2Request{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
TaskIDs: []UniqueID{at.GetTaskID()},
JobType: indexpb.JobType_JobTypeAnalyzeJob,
})
if err == nil {
err = merr.Error(resp)
}
if err != nil {
log.Ctx(ctx).Warn("notify worker drop the analysis task fail", zap.Int64("taskID", at.GetTaskID()),
zap.Int64("nodeID", at.GetNodeID()), zap.Error(err))
return false
}
log.Ctx(ctx).Info("drop analyze on worker success",
zap.Int64("taskID", at.GetTaskID()), zap.Int64("nodeID", at.GetNodeID()))
return true
}
func (at *analyzeTask) SetJobInfo(meta *meta) error {
return meta.analyzeMeta.FinishTask(at.GetTaskID(), at.taskInfo)
}

View File

@ -0,0 +1,337 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"path"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
itypeutil "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type indexBuildTask struct {
buildID int64
nodeID int64
taskInfo *indexpb.IndexTaskInfo
}
func (it *indexBuildTask) GetTaskID() int64 {
return it.buildID
}
func (it *indexBuildTask) GetNodeID() int64 {
return it.nodeID
}
func (it *indexBuildTask) ResetNodeID() {
it.nodeID = 0
}
func (it *indexBuildTask) CheckTaskHealthy(mt *meta) bool {
_, exist := mt.indexMeta.GetIndexJob(it.GetTaskID())
return exist
}
func (it *indexBuildTask) SetState(state indexpb.JobState, failReason string) {
it.taskInfo.State = commonpb.IndexState(state)
it.taskInfo.FailReason = failReason
}
func (it *indexBuildTask) GetState() indexpb.JobState {
return indexpb.JobState(it.taskInfo.GetState())
}
func (it *indexBuildTask) GetFailReason() string {
return it.taskInfo.FailReason
}
func (it *indexBuildTask) UpdateVersion(ctx context.Context, meta *meta) error {
return meta.indexMeta.UpdateVersion(it.buildID)
}
func (it *indexBuildTask) UpdateMetaBuildingState(nodeID int64, meta *meta) error {
it.nodeID = nodeID
return meta.indexMeta.BuildIndex(it.buildID, nodeID)
}
func (it *indexBuildTask) AssignTask(ctx context.Context, client types.IndexNodeClient, dependency *taskScheduler) (bool, bool) {
segIndex, exist := dependency.meta.indexMeta.GetIndexJob(it.buildID)
if !exist || segIndex == nil {
log.Ctx(ctx).Info("index task has not exist in meta table, remove task", zap.Int64("buildID", it.buildID))
it.SetState(indexpb.JobState_JobStateNone, "index task has not exist in meta table")
return false, false
}
segment := dependency.meta.GetSegment(segIndex.SegmentID)
if !isSegmentHealthy(segment) || !dependency.meta.indexMeta.IsIndexExist(segIndex.CollectionID, segIndex.IndexID) {
log.Ctx(ctx).Info("task is no need to build index, remove it", zap.Int64("buildID", it.buildID))
it.SetState(indexpb.JobState_JobStateNone, "task is no need to build index")
return false, false
}
indexParams := dependency.meta.indexMeta.GetIndexParams(segIndex.CollectionID, segIndex.IndexID)
indexType := GetIndexType(indexParams)
if isFlatIndex(indexType) || segIndex.NumRows < Params.DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64() {
log.Ctx(ctx).Info("segment does not need index really", zap.Int64("buildID", it.buildID),
zap.Int64("segmentID", segIndex.SegmentID), zap.Int64("num rows", segIndex.NumRows))
it.SetState(indexpb.JobState_JobStateFinished, "fake finished index success")
return true, true
}
// vector index build needs the data of optional scalar fields
optionalFields := make([]*indexpb.OptionalFieldInfo, 0)
if Params.CommonCfg.EnableMaterializedView.GetAsBool() && isOptionalScalarFieldSupported(indexType) {
collInfo, err := dependency.handler.GetCollection(ctx, segIndex.CollectionID)
if err != nil || collInfo == nil {
log.Ctx(ctx).Warn("get collection failed", zap.Int64("collID", segIndex.CollectionID), zap.Error(err))
it.SetState(indexpb.JobState_JobStateInit, err.Error())
return false, false
}
colSchema := collInfo.Schema
partitionKeyField, err := typeutil.GetPartitionKeyFieldSchema(colSchema)
if partitionKeyField == nil || err != nil {
log.Ctx(ctx).Warn("index builder get partition key field failed", zap.Int64("buildID", it.buildID), zap.Error(err))
} else {
optionalFields = append(optionalFields, &indexpb.OptionalFieldInfo{
FieldID: partitionKeyField.FieldID,
FieldName: partitionKeyField.Name,
FieldType: int32(partitionKeyField.DataType),
DataIds: getBinLogIDs(segment, partitionKeyField.FieldID),
})
}
}
typeParams := dependency.meta.indexMeta.GetTypeParams(segIndex.CollectionID, segIndex.IndexID)
var storageConfig *indexpb.StorageConfig
if Params.CommonCfg.StorageType.GetValue() == "local" {
storageConfig = &indexpb.StorageConfig{
RootPath: Params.LocalStorageCfg.Path.GetValue(),
StorageType: Params.CommonCfg.StorageType.GetValue(),
}
} else {
storageConfig = &indexpb.StorageConfig{
Address: Params.MinioCfg.Address.GetValue(),
AccessKeyID: Params.MinioCfg.AccessKeyID.GetValue(),
SecretAccessKey: Params.MinioCfg.SecretAccessKey.GetValue(),
UseSSL: Params.MinioCfg.UseSSL.GetAsBool(),
SslCACert: Params.MinioCfg.SslCACert.GetValue(),
BucketName: Params.MinioCfg.BucketName.GetValue(),
RootPath: Params.MinioCfg.RootPath.GetValue(),
UseIAM: Params.MinioCfg.UseIAM.GetAsBool(),
IAMEndpoint: Params.MinioCfg.IAMEndpoint.GetValue(),
StorageType: Params.CommonCfg.StorageType.GetValue(),
Region: Params.MinioCfg.Region.GetValue(),
UseVirtualHost: Params.MinioCfg.UseVirtualHost.GetAsBool(),
CloudProvider: Params.MinioCfg.CloudProvider.GetValue(),
RequestTimeoutMs: Params.MinioCfg.RequestTimeoutMs.GetAsInt64(),
}
}
fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID)
binlogIDs := getBinLogIDs(segment, fieldID)
if isDiskANNIndex(GetIndexType(indexParams)) {
var err error
indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams)
if err != nil {
log.Ctx(ctx).Warn("failed to append index build params", zap.Int64("buildID", it.buildID), zap.Error(err))
it.SetState(indexpb.JobState_JobStateInit, err.Error())
return false, false
}
}
var req *indexpb.CreateJobRequest
collectionInfo, err := dependency.handler.GetCollection(ctx, segment.GetCollectionID())
if err != nil {
log.Ctx(ctx).Info("index builder get collection info failed", zap.Int64("collectionID", segment.GetCollectionID()), zap.Error(err))
return false, false
}
schema := collectionInfo.Schema
var field *schemapb.FieldSchema
for _, f := range schema.Fields {
if f.FieldID == fieldID {
field = f
break
}
}
dim, err := storage.GetDimFromParams(field.TypeParams)
if err != nil {
log.Ctx(ctx).Warn("failed to get dim from field type params",
zap.String("field type", field.GetDataType().String()), zap.Error(err))
// do not return: the field may be a scalar field or a sparse float vector, which has no dim param
}
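// With storage v2 enabled the request additionally carries the segment and index
// storage URIs plus the store version; otherwise the worker resolves the field data
// from the binlog IDs in DataIds.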
if Params.CommonCfg.EnableStorageV2.GetAsBool() {
storePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue(), segment.GetID())
if err != nil {
log.Ctx(ctx).Warn("failed to get storage uri", zap.Error(err))
it.SetState(indexpb.JobState_JobStateInit, err.Error())
return false, false
}
indexStorePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue()+"/index", segment.GetID())
if err != nil {
log.Ctx(ctx).Warn("failed to get storage uri", zap.Error(err))
it.SetState(indexpb.JobState_JobStateInit, err.Error())
return false, false
}
req = &indexpb.CreateJobRequest{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
IndexFilePrefix: path.Join(dependency.chunkManager.RootPath(), common.SegmentIndexPath),
BuildID: it.buildID,
IndexVersion: segIndex.IndexVersion,
StorageConfig: storageConfig,
IndexParams: indexParams,
TypeParams: typeParams,
NumRows: segIndex.NumRows,
CollectionID: segment.GetCollectionID(),
PartitionID: segment.GetPartitionID(),
SegmentID: segment.GetID(),
FieldID: fieldID,
FieldName: field.Name,
FieldType: field.DataType,
StorePath: storePath,
StoreVersion: segment.GetStorageVersion(),
IndexStorePath: indexStorePath,
Dim: int64(dim),
CurrentIndexVersion: dependency.indexEngineVersionManager.GetCurrentIndexEngineVersion(),
DataIds: binlogIDs,
OptionalScalarFields: optionalFields,
Field: field,
}
} else {
req = &indexpb.CreateJobRequest{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
IndexFilePrefix: path.Join(dependency.chunkManager.RootPath(), common.SegmentIndexPath),
BuildID: it.buildID,
IndexVersion: segIndex.IndexVersion,
StorageConfig: storageConfig,
IndexParams: indexParams,
TypeParams: typeParams,
NumRows: segIndex.NumRows,
CurrentIndexVersion: dependency.indexEngineVersionManager.GetCurrentIndexEngineVersion(),
DataIds: binlogIDs,
CollectionID: segment.GetCollectionID(),
PartitionID: segment.GetPartitionID(),
SegmentID: segment.GetID(),
FieldID: fieldID,
OptionalScalarFields: optionalFields,
Dim: int64(dim),
Field: field,
}
}
ctx, cancel := context.WithTimeout(context.Background(), reqTimeoutInterval)
defer cancel()
resp, err := client.CreateJobV2(ctx, &indexpb.CreateJobV2Request{
ClusterID: req.GetClusterID(),
TaskID: req.GetBuildID(),
JobType: indexpb.JobType_JobTypeIndexJob,
Request: &indexpb.CreateJobV2Request_IndexRequest{
IndexRequest: req,
},
})
if err == nil {
err = merr.Error(resp)
}
if err != nil {
log.Ctx(ctx).Warn("assign index task to indexNode failed", zap.Int64("buildID", it.buildID), zap.Error(err))
it.SetState(indexpb.JobState_JobStateRetry, err.Error())
return false, true
}
log.Ctx(ctx).Info("index task assigned successfully", zap.Int64("buildID", it.buildID),
zap.Int64("segmentID", segIndex.SegmentID))
it.SetState(indexpb.JobState_JobStateInProgress, "")
return true, false
}
func (it *indexBuildTask) setResult(info *indexpb.IndexTaskInfo) {
it.taskInfo = info
}
func (it *indexBuildTask) QueryResult(ctx context.Context, node types.IndexNodeClient) {
resp, err := node.QueryJobsV2(ctx, &indexpb.QueryJobsV2Request{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
TaskIDs: []UniqueID{it.GetTaskID()},
JobType: indexpb.JobType_JobTypeIndexJob,
})
if err == nil {
err = merr.Error(resp.GetStatus())
}
if err != nil {
log.Ctx(ctx).Warn("get jobs info from IndexNode failed", zap.Int64("buildID", it.GetTaskID()),
zap.Int64("nodeID", it.GetNodeID()), zap.Error(err))
it.SetState(indexpb.JobState_JobStateRetry, err.Error())
return
}
// indexInfos length is always one.
for _, info := range resp.GetIndexJobResults().GetResults() {
if info.GetBuildID() == it.GetTaskID() {
log.Ctx(ctx).Info("query task index info successfully",
zap.Int64("taskID", it.GetTaskID()), zap.String("result state", info.GetState().String()),
zap.String("failReason", info.GetFailReason()))
if info.GetState() == commonpb.IndexState_Finished || info.GetState() == commonpb.IndexState_Failed ||
info.GetState() == commonpb.IndexState_Retry {
// state is retry or finished or failed
it.setResult(info)
} else if info.GetState() == commonpb.IndexState_IndexStateNone {
it.SetState(indexpb.JobState_JobStateRetry, "index state is none in info response")
}
// inProgress or unissued, keep InProgress state
return
}
}
it.SetState(indexpb.JobState_JobStateRetry, "index is not in info response")
}
func (it *indexBuildTask) DropTaskOnWorker(ctx context.Context, client types.IndexNodeClient) bool {
resp, err := client.DropJobsV2(ctx, &indexpb.DropJobsV2Request{
ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(),
TaskIDs: []UniqueID{it.GetTaskID()},
JobType: indexpb.JobType_JobTypeIndexJob,
})
if err == nil {
err = merr.Error(resp)
}
if err != nil {
log.Ctx(ctx).Warn("notify worker drop the index task fail", zap.Int64("taskID", it.GetTaskID()),
zap.Int64("nodeID", it.GetNodeID()), zap.Error(err))
return false
}
log.Ctx(ctx).Info("drop index task on worker success", zap.Int64("taskID", it.GetTaskID()),
zap.Int64("nodeID", it.GetNodeID()))
return true
}
func (it *indexBuildTask) SetJobInfo(meta *meta) error {
return meta.indexMeta.FinishTask(it.taskInfo)
}

View File

@ -0,0 +1,296 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"sync"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/log"
)
const (
reqTimeoutInterval = time.Second * 10
)
type taskScheduler struct {
sync.RWMutex
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
scheduleDuration time.Duration
// TODO @xiaocai2333: use priority queue
tasks map[int64]Task
notifyChan chan struct{}
meta *meta
policy buildIndexPolicy
nodeManager WorkerManager
chunkManager storage.ChunkManager
indexEngineVersionManager IndexEngineVersionManager
handler Handler
}
func newTaskScheduler(
ctx context.Context,
metaTable *meta, nodeManager WorkerManager,
chunkManager storage.ChunkManager,
indexEngineVersionManager IndexEngineVersionManager,
handler Handler,
) *taskScheduler {
ctx, cancel := context.WithCancel(ctx)
ts := &taskScheduler{
ctx: ctx,
cancel: cancel,
meta: metaTable,
tasks: make(map[int64]Task),
notifyChan: make(chan struct{}, 1),
scheduleDuration: Params.DataCoordCfg.IndexTaskSchedulerInterval.GetAsDuration(time.Millisecond),
policy: defaultBuildIndexPolicy,
nodeManager: nodeManager,
chunkManager: chunkManager,
handler: handler,
indexEngineVersionManager: indexEngineVersionManager,
}
ts.reloadFromKV()
return ts
}
func (s *taskScheduler) Start() {
s.wg.Add(1)
go s.schedule()
}
func (s *taskScheduler) Stop() {
s.cancel()
s.wg.Wait()
}
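// reloadFromKV rebuilds the in-memory task set from meta on startup: unfinished
// segment index builds and analyze tasks are re-registered so the scheduler can
// resume them after a DataCoord restart.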
func (s *taskScheduler) reloadFromKV() {
segments := s.meta.GetAllSegmentsUnsafe()
for _, segment := range segments {
for _, segIndex := range s.meta.indexMeta.getSegmentIndexes(segment.ID) {
if segIndex.IsDeleted {
continue
}
if segIndex.IndexState != commonpb.IndexState_Finished && segIndex.IndexState != commonpb.IndexState_Failed {
s.tasks[segIndex.BuildID] = &indexBuildTask{
buildID: segIndex.BuildID,
nodeID: segIndex.NodeID,
taskInfo: &indexpb.IndexTaskInfo{
BuildID: segIndex.BuildID,
State: segIndex.IndexState,
FailReason: segIndex.FailReason,
},
}
}
}
}
allAnalyzeTasks := s.meta.analyzeMeta.GetAllTasks()
for taskID, t := range allAnalyzeTasks {
if t.State != indexpb.JobState_JobStateFinished && t.State != indexpb.JobState_JobStateFailed {
s.tasks[taskID] = &analyzeTask{
taskID: taskID,
nodeID: t.NodeID,
taskInfo: &indexpb.AnalyzeResult{
TaskID: taskID,
State: t.State,
FailReason: t.FailReason,
},
}
}
}
}
// notify is a non-blocking notification; duplicate notifications are coalesced by the buffered channel
func (s *taskScheduler) notify() {
select {
case s.notifyChan <- struct{}{}:
default:
}
}
func (s *taskScheduler) enqueue(task Task) {
defer s.notify()
s.Lock()
defer s.Unlock()
taskID := task.GetTaskID()
if _, ok := s.tasks[taskID]; !ok {
s.tasks[taskID] = task
}
log.Info("taskScheduler enqueue task", zap.Int64("taskID", taskID))
}
func (s *taskScheduler) schedule() {
// The loop is driven by explicit notifications and a periodic ticker.
log.Ctx(s.ctx).Info("task scheduler loop start")
defer s.wg.Done()
ticker := time.NewTicker(s.scheduleDuration)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
log.Ctx(s.ctx).Warn("task scheduler ctx done")
return
case _, ok := <-s.notifyChan:
if ok {
s.run()
}
// !ok means the task scheduler has been closed.
case <-ticker.C:
s.run()
}
}
}
func (s *taskScheduler) getTask(taskID UniqueID) Task {
s.RLock()
defer s.RUnlock()
return s.tasks[taskID]
}
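// run snapshots the pending task IDs, orders them with the configured policy, and
// processes them one by one, stopping early for this round once process reports
// that no task could be handed out.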
func (s *taskScheduler) run() {
// schedule policy
s.RLock()
taskIDs := make([]UniqueID, 0, len(s.tasks))
for tID := range s.tasks {
taskIDs = append(taskIDs, tID)
}
s.RUnlock()
if len(taskIDs) > 0 {
log.Ctx(s.ctx).Info("task scheduler", zap.Int("task num", len(taskIDs)))
}
s.policy(taskIDs)
for _, taskID := range taskIDs {
ok := s.process(taskID)
if !ok {
log.Ctx(s.ctx).Info("there is no idle indexing node, wait a minute...")
break
}
}
}
func (s *taskScheduler) removeTask(taskID UniqueID) {
s.Lock()
defer s.Unlock()
delete(s.tasks, taskID)
}
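// process advances a single task one step through its state machine. It returns
// false when the task could not make progress in a way that should pause this
// scheduling round (e.g. no idle IndexNode, or a version/meta update failure).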
func (s *taskScheduler) process(taskID UniqueID) bool {
task := s.getTask(taskID)
if !task.CheckTaskHealthy(s.meta) {
s.removeTask(taskID)
return true
}
state := task.GetState()
log.Ctx(s.ctx).Info("task is processing", zap.Int64("taskID", taskID),
zap.String("state", state.String()))
switch state {
case indexpb.JobState_JobStateNone:
s.removeTask(taskID)
case indexpb.JobState_JobStateInit:
// 1. pick an indexNode client
nodeID, client := s.nodeManager.PickClient()
if client == nil {
log.Ctx(s.ctx).Debug("pick client failed")
return false
}
log.Ctx(s.ctx).Info("pick client success", zap.Int64("taskID", taskID), zap.Int64("nodeID", nodeID))
// 2. update version
if err := task.UpdateVersion(s.ctx, s.meta); err != nil {
log.Ctx(s.ctx).Warn("update task version failed", zap.Int64("taskID", taskID), zap.Error(err))
return false
}
log.Ctx(s.ctx).Info("update task version success", zap.Int64("taskID", taskID))
// 3. assign task to indexNode
success, skip := task.AssignTask(s.ctx, client, s)
if !success {
log.Ctx(s.ctx).Warn("assign task to client failed", zap.Int64("taskID", taskID),
zap.String("new state", task.GetState().String()), zap.String("fail reason", task.GetFailReason()))
// If the problem is caused by the task itself, subsequent tasks will not be skipped.
// If etcd fails or fails to send tasks to the node, the subsequent tasks will be skipped.
return !skip
}
if skip {
// the task finished immediately (e.g. index for a small segment with fewer than 1024 rows), skip the remaining steps.
return true
}
log.Ctx(s.ctx).Info("assign task to client success", zap.Int64("taskID", taskID), zap.Int64("nodeID", nodeID))
// 4. update meta state
if err := task.UpdateMetaBuildingState(nodeID, s.meta); err != nil {
log.Ctx(s.ctx).Warn("update meta building state failed", zap.Int64("taskID", taskID), zap.Error(err))
task.SetState(indexpb.JobState_JobStateRetry, "update meta building state failed")
return false
}
log.Ctx(s.ctx).Info("update task meta state to InProgress success", zap.Int64("taskID", taskID),
zap.Int64("nodeID", nodeID))
case indexpb.JobState_JobStateFinished, indexpb.JobState_JobStateFailed:
if err := task.SetJobInfo(s.meta); err != nil {
log.Ctx(s.ctx).Warn("update task info failed", zap.Error(err))
return true
}
client, exist := s.nodeManager.GetClientByID(task.GetNodeID())
if exist {
if !task.DropTaskOnWorker(s.ctx, client) {
return true
}
}
s.removeTask(taskID)
case indexpb.JobState_JobStateRetry:
client, exist := s.nodeManager.GetClientByID(task.GetNodeID())
if exist {
if !task.DropTaskOnWorker(s.ctx, client) {
return true
}
}
task.SetState(indexpb.JobState_JobStateInit, "")
task.ResetNodeID()
default:
// state: in_progress
client, exist := s.nodeManager.GetClientByID(task.GetNodeID())
if exist {
task.QueryResult(s.ctx, client)
return true
}
task.SetState(indexpb.JobState_JobStateRetry, "")
}
return true
}

File diff suppressed because it is too large

View File

@ -0,0 +1,40 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/types"
)
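// Task is the scheduling unit shared by index-build and analyze jobs; the
// taskScheduler drives it through its JobState machine (Init -> InProgress ->
// Finished/Failed/Retry). AssignTask returns two flags: the first reports whether
// the task was handed to a worker (or finished on the spot), and the second tells
// the scheduler either to skip the remaining per-task steps (on success) or to stop
// the current scheduling round (on failure caused by the node rather than the task).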
type Task interface {
GetTaskID() int64
GetNodeID() int64
ResetNodeID()
CheckTaskHealthy(mt *meta) bool
SetState(state indexpb.JobState, failReason string)
GetState() indexpb.JobState
GetFailReason() string
UpdateVersion(ctx context.Context, meta *meta) error
UpdateMetaBuildingState(nodeID int64, meta *meta) error
AssignTask(ctx context.Context, client types.IndexNodeClient, dependency *taskScheduler) (bool, bool)
QueryResult(ctx context.Context, client types.IndexNodeClient)
DropTaskOnWorker(ctx context.Context, client types.IndexNodeClient) bool
SetJobInfo(meta *meta) error
}

View File

@ -196,6 +196,10 @@ func isFlatIndex(indexType string) bool {
return indexType == indexparamcheck.IndexFaissIDMap || indexType == indexparamcheck.IndexFaissBinIDMap
}
func isOptionalScalarFieldSupported(indexType string) bool {
return indexType == indexparamcheck.IndexHNSW
}
func isDiskANNIndex(indexType string) bool {
return indexType == indexparamcheck.IndexDISKANN
}
@ -256,3 +260,16 @@ func getCompactionMergeInfo(task *datapb.CompactionTask) *milvuspb.CompactionMer
Target: target,
}
}
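// getBinLogIDs returns the log IDs of all binlogs that belong to the given field of
// the segment; they are handed to the IndexNode so it can locate the field data.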
func getBinLogIDs(segment *SegmentInfo, fieldID int64) []int64 {
binlogIDs := make([]int64, 0)
for _, fieldBinLog := range segment.GetBinlogs() {
if fieldBinLog.GetFieldID() == fieldID {
for _, binLog := range fieldBinLog.GetBinlogs() {
binlogIDs = append(binlogIDs, binLog.GetLogID())
}
break
}
}
return binlogIDs
}

View File

@ -163,3 +163,21 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest
return client.GetMetrics(ctx, req)
})
}
func (c *Client) CreateJobV2(ctx context.Context, req *indexpb.CreateJobV2Request, opts ...grpc.CallOption) (*commonpb.Status, error) {
return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*commonpb.Status, error) {
return client.CreateJobV2(ctx, req)
})
}
func (c *Client) QueryJobsV2(ctx context.Context, req *indexpb.QueryJobsV2Request, opts ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) {
return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*indexpb.QueryJobsV2Response, error) {
return client.QueryJobsV2(ctx, req)
})
}
func (c *Client) DropJobsV2(ctx context.Context, req *indexpb.DropJobsV2Request, opt ...grpc.CallOption) (*commonpb.Status, error) {
return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*commonpb.Status, error) {
return client.DropJobsV2(ctx, req)
})
}

View File

@ -164,6 +164,24 @@ func TestIndexNodeClient(t *testing.T) {
assert.NoError(t, err)
})
t.Run("CreateJobV2", func(t *testing.T) {
req := &indexpb.CreateJobV2Request{}
_, err := inc.CreateJobV2(ctx, req)
assert.NoError(t, err)
})
t.Run("QueryJobsV2", func(t *testing.T) {
req := &indexpb.QueryJobsV2Request{}
_, err := inc.QueryJobsV2(ctx, req)
assert.NoError(t, err)
})
t.Run("DropJobsV2", func(t *testing.T) {
req := &indexpb.DropJobsV2Request{}
_, err := inc.DropJobsV2(ctx, req)
assert.NoError(t, err)
})
err := inc.Close()
assert.NoError(t, err)
}

View File

@ -289,6 +289,18 @@ func (s *Server) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsReq
return s.indexnode.GetMetrics(ctx, request)
}
func (s *Server) CreateJobV2(ctx context.Context, request *indexpb.CreateJobV2Request) (*commonpb.Status, error) {
return s.indexnode.CreateJobV2(ctx, request)
}
func (s *Server) QueryJobsV2(ctx context.Context, request *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) {
return s.indexnode.QueryJobsV2(ctx, request)
}
func (s *Server) DropJobsV2(ctx context.Context, request *indexpb.DropJobsV2Request) (*commonpb.Status, error) {
return s.indexnode.DropJobsV2(ctx, request)
}
// NewServer create a new IndexNode grpc server.
func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) {
ctx1, cancel := context.WithCancel(ctx)

View File

@ -21,13 +21,15 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/indexnode"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@ -40,7 +42,13 @@ func TestIndexNodeServer(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, server)
inm := indexnode.NewIndexNodeMock()
inm := mocks.NewMockIndexNode(t)
inm.EXPECT().SetEtcdClient(mock.Anything).Return()
inm.EXPECT().SetAddress(mock.Anything).Return()
inm.EXPECT().Start().Return(nil)
inm.EXPECT().Init().Return(nil)
inm.EXPECT().Register().Return(nil)
inm.EXPECT().Stop().Return(nil)
err = server.setServer(inm)
assert.NoError(t, err)
@ -48,6 +56,11 @@ func TestIndexNodeServer(t *testing.T) {
assert.NoError(t, err)
t.Run("GetComponentStates", func(t *testing.T) {
inm.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Healthy,
},
}, nil)
req := &milvuspb.GetComponentStatesRequest{}
states, err := server.GetComponentStates(ctx, req)
assert.NoError(t, err)
@ -55,6 +68,9 @@ func TestIndexNodeServer(t *testing.T) {
})
t.Run("GetStatisticsChannel", func(t *testing.T) {
inm.EXPECT().GetStatisticsChannel(mock.Anything, mock.Anything).Return(&milvuspb.StringResponse{
Status: merr.Success(),
}, nil)
req := &internalpb.GetStatisticsChannelRequest{}
resp, err := server.GetStatisticsChannel(ctx, req)
assert.NoError(t, err)
@ -62,6 +78,7 @@ func TestIndexNodeServer(t *testing.T) {
})
t.Run("CreateJob", func(t *testing.T) {
inm.EXPECT().CreateJob(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req := &indexpb.CreateJobRequest{
ClusterID: "",
BuildID: 0,
@ -74,6 +91,9 @@ func TestIndexNodeServer(t *testing.T) {
})
t.Run("QueryJob", func(t *testing.T) {
inm.EXPECT().QueryJobs(mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{
Status: merr.Success(),
}, nil)
req := &indexpb.QueryJobsRequest{}
resp, err := server.QueryJobs(ctx, req)
assert.NoError(t, err)
@ -81,6 +101,7 @@ func TestIndexNodeServer(t *testing.T) {
})
t.Run("DropJobs", func(t *testing.T) {
inm.EXPECT().DropJobs(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req := &indexpb.DropJobsRequest{}
resp, err := server.DropJobs(ctx, req)
assert.NoError(t, err)
@ -88,6 +109,9 @@ func TestIndexNodeServer(t *testing.T) {
})
t.Run("ShowConfigurations", func(t *testing.T) {
inm.EXPECT().ShowConfigurations(mock.Anything, mock.Anything).Return(&internalpb.ShowConfigurationsResponse{
Status: merr.Success(),
}, nil)
req := &internalpb.ShowConfigurationsRequest{
Pattern: "",
}
@ -97,6 +121,9 @@ func TestIndexNodeServer(t *testing.T) {
})
t.Run("GetMetrics", func(t *testing.T) {
inm.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{
Status: merr.Success(),
}, nil)
req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics)
assert.NoError(t, err)
resp, err := server.GetMetrics(ctx, req)
@ -105,12 +132,41 @@ func TestIndexNodeServer(t *testing.T) {
})
t.Run("GetTaskSlots", func(t *testing.T) {
inm.EXPECT().GetJobStats(mock.Anything, mock.Anything).Return(&indexpb.GetJobStatsResponse{
Status: merr.Success(),
}, nil)
req := &indexpb.GetJobStatsRequest{}
resp, err := server.GetJobStats(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("CreateJobV2", func(t *testing.T) {
inm.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req := &indexpb.CreateJobV2Request{}
resp, err := server.CreateJobV2(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("QueryJobsV2", func(t *testing.T) {
inm.EXPECT().QueryJobsV2(mock.Anything, mock.Anything).Return(&indexpb.QueryJobsV2Response{
Status: merr.Success(),
}, nil)
req := &indexpb.QueryJobsV2Request{}
resp, err := server.QueryJobsV2(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
})
t.Run("DropJobsV2", func(t *testing.T) {
inm.EXPECT().DropJobsV2(mock.Anything, mock.Anything).Return(merr.Success(), nil)
req := &indexpb.DropJobsV2Request{}
resp, err := server.DropJobsV2(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
err = server.Stop()
assert.NoError(t, err)
}

View File

@ -17,7 +17,7 @@
package indexnode
/*
#cgo pkg-config: milvus_common milvus_indexbuilder milvus_segcore
#cgo pkg-config: milvus_common milvus_indexbuilder milvus_clustering milvus_segcore
#include <stdlib.h>
#include <stdint.h>
@ -105,9 +105,10 @@ type IndexNode struct {
etcdCli *clientv3.Client
address string
initOnce sync.Once
stateLock sync.Mutex
tasks map[taskKey]*taskInfo
initOnce sync.Once
stateLock sync.Mutex
indexTasks map[taskKey]*indexTaskInfo
analyzeTasks map[taskKey]*analyzeTaskInfo
}
// NewIndexNode creates a new IndexNode component.
@ -120,7 +121,8 @@ func NewIndexNode(ctx context.Context, factory dependency.Factory) *IndexNode {
loopCancel: cancel,
factory: factory,
storageFactory: NewChunkMgrFactory(),
tasks: map[taskKey]*taskInfo{},
indexTasks: make(map[taskKey]*indexTaskInfo),
analyzeTasks: make(map[taskKey]*analyzeTaskInfo),
lifetime: lifetime.NewLifetime(commonpb.StateCode_Abnormal),
}
sc := NewTaskScheduler(b.loopCtx)
@ -251,10 +253,16 @@ func (i *IndexNode) Stop() error {
i.lifetime.Wait()
log.Info("Index node abnormal")
// cleanup all running tasks
deletedTasks := i.deleteAllTasks()
for _, task := range deletedTasks {
if task.cancel != nil {
task.cancel()
deletedIndexTasks := i.deleteAllIndexTasks()
for _, t := range deletedIndexTasks {
if t.cancel != nil {
t.cancel()
}
}
deletedAnalyzeTasks := i.deleteAllAnalyzeTasks()
for _, t := range deletedAnalyzeTasks {
if t.cancel != nil {
t.cancel()
}
}
if i.sched != nil {

View File

@ -18,7 +18,9 @@ package indexnode
import (
"context"
"fmt"
"github.com/cockroachdb/errors"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -52,6 +54,9 @@ type Mock struct {
CallQueryJobs func(ctx context.Context, in *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error)
CallDropJobs func(ctx context.Context, in *indexpb.DropJobsRequest) (*commonpb.Status, error)
CallGetJobStats func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error)
CallCreateJobV2 func(ctx context.Context, req *indexpb.CreateJobV2Request) (*commonpb.Status, error)
CallQueryJobV2 func(ctx context.Context, req *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error)
CallDropJobV2 func(ctx context.Context, req *indexpb.DropJobsV2Request) (*commonpb.Status, error)
CallGetMetrics func(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)
CallShowConfigurations func(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)
@ -114,6 +119,62 @@ func NewIndexNodeMock() *Mock {
CallDropJobs: func(ctx context.Context, in *indexpb.DropJobsRequest) (*commonpb.Status, error) {
return merr.Success(), nil
},
CallCreateJobV2: func(ctx context.Context, req *indexpb.CreateJobV2Request) (*commonpb.Status, error) {
return merr.Success(), nil
},
CallQueryJobV2: func(ctx context.Context, req *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) {
switch req.GetJobType() {
case indexpb.JobType_JobTypeIndexJob:
results := make([]*indexpb.IndexTaskInfo, 0)
for _, buildID := range req.GetTaskIDs() {
results = append(results, &indexpb.IndexTaskInfo{
BuildID: buildID,
State: commonpb.IndexState_Finished,
IndexFileKeys: []string{},
SerializedSize: 1024,
FailReason: "",
CurrentIndexVersion: 1,
IndexStoreVersion: 1,
})
}
return &indexpb.QueryJobsV2Response{
Status: merr.Success(),
ClusterID: req.GetClusterID(),
Result: &indexpb.QueryJobsV2Response_IndexJobResults{
IndexJobResults: &indexpb.IndexJobResults{
Results: results,
},
},
}, nil
case indexpb.JobType_JobTypeAnalyzeJob:
results := make([]*indexpb.AnalyzeResult, 0)
for _, taskID := range req.GetTaskIDs() {
results = append(results, &indexpb.AnalyzeResult{
TaskID: taskID,
State: indexpb.JobState_JobStateFinished,
CentroidsFile: fmt.Sprintf("%d/stats_file", taskID),
FailReason: "",
})
}
return &indexpb.QueryJobsV2Response{
Status: merr.Success(),
ClusterID: req.GetClusterID(),
Result: &indexpb.QueryJobsV2Response_AnalyzeJobResults{
AnalyzeJobResults: &indexpb.AnalyzeResults{
Results: results,
},
},
}, nil
default:
return &indexpb.QueryJobsV2Response{
Status: merr.Status(errors.New("unknown job type")),
ClusterID: req.GetClusterID(),
}, nil
}
},
CallDropJobV2: func(ctx context.Context, req *indexpb.DropJobsV2Request) (*commonpb.Status, error) {
return merr.Success(), nil
},
CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) {
return &indexpb.GetJobStatsResponse{
Status: merr.Success(),
@ -201,6 +262,18 @@ func (m *Mock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest)
return m.CallGetMetrics(ctx, req)
}
func (m *Mock) CreateJobV2(ctx context.Context, req *indexpb.CreateJobV2Request) (*commonpb.Status, error) {
return m.CallCreateJobV2(ctx, req)
}
func (m *Mock) QueryJobsV2(ctx context.Context, req *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) {
return m.CallQueryJobV2(ctx, req)
}
func (m *Mock) DropJobsV2(ctx context.Context, req *indexpb.DropJobsV2Request) (*commonpb.Status, error) {
return m.CallDropJobV2(ctx, req)
}
// ShowConfigurations returns the configurations of Mock indexNode matching req.Pattern
func (m *Mock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
return m.CallShowConfigurations(ctx, req)

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -77,7 +78,7 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest
metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.TotalLabel).Inc()
taskCtx, taskCancel := context.WithCancel(i.loopCtx)
if oldInfo := i.loadOrStoreTask(req.GetClusterID(), req.GetBuildID(), &taskInfo{
if oldInfo := i.loadOrStoreIndexTask(req.GetClusterID(), req.GetBuildID(), &indexTaskInfo{
cancel: taskCancel,
state: commonpb.IndexState_InProgress,
}); oldInfo != nil {
@ -92,7 +93,7 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest
zap.String("accessKey", req.GetStorageConfig().GetAccessKeyID()),
zap.Error(err),
)
i.deleteTaskInfos(ctx, []taskKey{{ClusterID: req.GetClusterID(), BuildID: req.GetBuildID()}})
i.deleteIndexTaskInfos(ctx, []taskKey{{ClusterID: req.GetClusterID(), BuildID: req.GetBuildID()}})
metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.FailLabel).Inc()
return merr.Status(err), nil
}
@ -103,7 +104,7 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest
task = newIndexBuildTask(taskCtx, taskCancel, req, cm, i)
}
ret := merr.Success()
if err := i.sched.IndexBuildQueue.Enqueue(task); err != nil {
if err := i.sched.TaskQueue.Enqueue(task); err != nil {
log.Warn("IndexNode failed to schedule",
zap.Error(err))
ret = merr.Status(err)
@ -127,10 +128,10 @@ func (i *IndexNode) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest
}, nil
}
defer i.lifetime.Done()
infos := make(map[UniqueID]*taskInfo)
i.foreachTaskInfo(func(ClusterID string, buildID UniqueID, info *taskInfo) {
infos := make(map[UniqueID]*indexTaskInfo)
i.foreachIndexTaskInfo(func(ClusterID string, buildID UniqueID, info *indexTaskInfo) {
if ClusterID == req.GetClusterID() {
infos[buildID] = &taskInfo{
infos[buildID] = &indexTaskInfo{
state: info.state,
fileKeys: common.CloneStringList(info.fileKeys),
serializedSize: info.serializedSize,
@ -183,7 +184,7 @@ func (i *IndexNode) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest)
for _, buildID := range req.GetBuildIDs() {
keys = append(keys, taskKey{ClusterID: req.GetClusterID(), BuildID: buildID})
}
infos := i.deleteTaskInfos(ctx, keys)
infos := i.deleteIndexTaskInfos(ctx, keys)
for _, info := range infos {
if info.cancel != nil {
info.cancel()
@ -203,7 +204,8 @@ func (i *IndexNode) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsReq
}, nil
}
defer i.lifetime.Done()
unissued, active := i.sched.IndexBuildQueue.GetTaskNum()
unissued, active := i.sched.TaskQueue.GetTaskNum()
slots := 0
if i.sched.buildParallel > unissued+active {
slots = i.sched.buildParallel - unissued - active
@ -271,3 +273,250 @@ func (i *IndexNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequ
Status: merr.Status(merr.WrapErrMetricNotFound(metricType)),
}, nil
}
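// CreateJobV2 registers a job on this IndexNode and enqueues it into the node-local
// task scheduler; the JobType selects between an index build task and an analyze task.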
func (i *IndexNode) CreateJobV2(ctx context.Context, req *indexpb.CreateJobV2Request) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(
zap.String("clusterID", req.GetClusterID()), zap.Int64("taskID", req.GetTaskID()),
zap.String("jobType", req.GetJobType().String()),
)
if err := i.lifetime.Add(merr.IsHealthy); err != nil {
log.Warn("index node not ready",
zap.Error(err),
)
return merr.Status(err), nil
}
defer i.lifetime.Done()
log.Info("IndexNode receive CreateJob request...")
switch req.GetJobType() {
case indexpb.JobType_JobTypeIndexJob:
indexRequest := req.GetIndexRequest()
log.Info("IndexNode building index ...",
zap.Int64("indexID", indexRequest.GetIndexID()),
zap.String("indexName", indexRequest.GetIndexName()),
zap.String("indexFilePrefix", indexRequest.GetIndexFilePrefix()),
zap.Int64("indexVersion", indexRequest.GetIndexVersion()),
zap.Strings("dataPaths", indexRequest.GetDataPaths()),
zap.Any("typeParams", indexRequest.GetTypeParams()),
zap.Any("indexParams", indexRequest.GetIndexParams()),
zap.Int64("numRows", indexRequest.GetNumRows()),
zap.Int32("current_index_version", indexRequest.GetCurrentIndexVersion()),
zap.String("storePath", indexRequest.GetStorePath()),
zap.Int64("storeVersion", indexRequest.GetStoreVersion()),
zap.String("indexStorePath", indexRequest.GetIndexStorePath()),
zap.Int64("dim", indexRequest.GetDim()))
taskCtx, taskCancel := context.WithCancel(i.loopCtx)
if oldInfo := i.loadOrStoreIndexTask(indexRequest.GetClusterID(), indexRequest.GetBuildID(), &indexTaskInfo{
cancel: taskCancel,
state: commonpb.IndexState_InProgress,
}); oldInfo != nil {
err := merr.WrapErrIndexDuplicate(indexRequest.GetIndexName(), "building index task existed")
log.Warn("duplicated index build task", zap.Error(err))
metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.FailLabel).Inc()
return merr.Status(err), nil
}
cm, err := i.storageFactory.NewChunkManager(i.loopCtx, indexRequest.GetStorageConfig())
if err != nil {
log.Error("create chunk manager failed", zap.String("bucket", indexRequest.GetStorageConfig().GetBucketName()),
zap.String("accessKey", indexRequest.GetStorageConfig().GetAccessKeyID()),
zap.Error(err),
)
i.deleteIndexTaskInfos(ctx, []taskKey{{ClusterID: indexRequest.GetClusterID(), BuildID: indexRequest.GetBuildID()}})
metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.FailLabel).Inc()
return merr.Status(err), nil
}
var task task
if Params.CommonCfg.EnableStorageV2.GetAsBool() {
task = newIndexBuildTaskV2(taskCtx, taskCancel, indexRequest, i)
} else {
task = newIndexBuildTask(taskCtx, taskCancel, indexRequest, cm, i)
}
ret := merr.Success()
if err := i.sched.TaskQueue.Enqueue(task); err != nil {
log.Warn("IndexNode failed to schedule",
zap.Error(err))
ret = merr.Status(err)
metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.FailLabel).Inc()
return ret, nil
}
metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SuccessLabel).Inc()
log.Info("IndexNode index job enqueued successfully",
zap.String("indexName", indexRequest.GetIndexName()))
return ret, nil
case indexpb.JobType_JobTypeAnalyzeJob:
analyzeRequest := req.GetAnalyzeRequest()
log.Info("receive analyze job", zap.Int64("collectionID", analyzeRequest.GetCollectionID()),
zap.Int64("partitionID", analyzeRequest.GetPartitionID()),
zap.Int64("fieldID", analyzeRequest.GetFieldID()),
zap.String("fieldName", analyzeRequest.GetFieldName()),
zap.String("dataType", analyzeRequest.GetFieldType().String()),
zap.Int64("version", analyzeRequest.GetVersion()),
zap.Int64("dim", analyzeRequest.GetDim()),
zap.Float64("trainSizeRatio", analyzeRequest.GetMaxTrainSizeRatio()),
zap.Int64("numClusters", analyzeRequest.GetNumClusters()),
)
taskCtx, taskCancel := context.WithCancel(i.loopCtx)
if oldInfo := i.loadOrStoreAnalyzeTask(analyzeRequest.GetClusterID(), analyzeRequest.GetTaskID(), &analyzeTaskInfo{
cancel: taskCancel,
state: indexpb.JobState_JobStateInProgress,
}); oldInfo != nil {
err := merr.WrapErrIndexDuplicate("", "analyze task already existed")
log.Warn("duplicated analyze task", zap.Error(err))
return merr.Status(err), nil
}
t := &analyzeTask{
ident: fmt.Sprintf("%s/%d", analyzeRequest.GetClusterID(), analyzeRequest.GetTaskID()),
ctx: taskCtx,
cancel: taskCancel,
req: analyzeRequest,
node: i,
tr: timerecord.NewTimeRecorder(fmt.Sprintf("ClusterID: %s, TaskID: %d", req.GetClusterID(), req.GetTaskID())),
}
ret := merr.Success()
if err := i.sched.TaskQueue.Enqueue(t); err != nil {
log.Warn("IndexNode failed to schedule", zap.Error(err))
ret = merr.Status(err)
return ret, nil
}
log.Info("IndexNode analyze job enqueued successfully")
return ret, nil
default:
log.Warn("IndexNode receive unknown type job")
return merr.Status(fmt.Errorf("IndexNode receive unknown type job with taskID: %d", req.GetTaskID())), nil
}
}
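// QueryJobsV2 reports the in-memory state of the requested tasks. For index jobs,
// tasks unknown to this node are returned with IndexStateNone so the coordinator can
// retry them; for analyze jobs, unknown tasks are simply omitted from the result.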
func (i *IndexNode) QueryJobsV2(ctx context.Context, req *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) {
log := log.Ctx(ctx).With(
zap.String("clusterID", req.GetClusterID()), zap.Int64s("taskIDs", req.GetTaskIDs()),
).WithRateGroup("QueryResult", 1, 60)
if err := i.lifetime.Add(merr.IsHealthyOrStopping); err != nil {
log.Warn("IndexNode not ready", zap.Error(err))
return &indexpb.QueryJobsV2Response{
Status: merr.Status(err),
}, nil
}
defer i.lifetime.Done()
switch req.GetJobType() {
case indexpb.JobType_JobTypeIndexJob:
infos := make(map[UniqueID]*indexTaskInfo)
i.foreachIndexTaskInfo(func(ClusterID string, buildID UniqueID, info *indexTaskInfo) {
if ClusterID == req.GetClusterID() {
infos[buildID] = &indexTaskInfo{
state: info.state,
fileKeys: common.CloneStringList(info.fileKeys),
serializedSize: info.serializedSize,
failReason: info.failReason,
currentIndexVersion: info.currentIndexVersion,
indexStoreVersion: info.indexStoreVersion,
}
}
})
results := make([]*indexpb.IndexTaskInfo, 0, len(req.GetTaskIDs()))
for i, buildID := range req.GetTaskIDs() {
results = append(results, &indexpb.IndexTaskInfo{
BuildID: buildID,
State: commonpb.IndexState_IndexStateNone,
IndexFileKeys: nil,
SerializedSize: 0,
})
if info, ok := infos[buildID]; ok {
results[i].State = info.state
results[i].IndexFileKeys = info.fileKeys
results[i].SerializedSize = info.serializedSize
results[i].FailReason = info.failReason
results[i].CurrentIndexVersion = info.currentIndexVersion
results[i].IndexStoreVersion = info.indexStoreVersion
}
}
log.Debug("query index jobs result success", zap.Any("results", results))
return &indexpb.QueryJobsV2Response{
Status: merr.Success(),
ClusterID: req.GetClusterID(),
Result: &indexpb.QueryJobsV2Response_IndexJobResults{
IndexJobResults: &indexpb.IndexJobResults{
Results: results,
},
},
}, nil
case indexpb.JobType_JobTypeAnalyzeJob:
results := make([]*indexpb.AnalyzeResult, 0, len(req.GetTaskIDs()))
for _, taskID := range req.GetTaskIDs() {
info := i.getAnalyzeTaskInfo(req.GetClusterID(), taskID)
if info != nil {
results = append(results, &indexpb.AnalyzeResult{
TaskID: taskID,
State: info.state,
FailReason: info.failReason,
CentroidsFile: info.centroidsFile,
})
}
}
log.Debug("query analyze jobs result success", zap.Any("results", results))
return &indexpb.QueryJobsV2Response{
Status: merr.Success(),
ClusterID: req.GetClusterID(),
Result: &indexpb.QueryJobsV2Response_AnalyzeJobResults{
AnalyzeJobResults: &indexpb.AnalyzeResults{
Results: results,
},
},
}, nil
default:
log.Warn("IndexNode receive querying unknown type jobs")
return &indexpb.QueryJobsV2Response{
Status: merr.Status(fmt.Errorf("IndexNode receive querying unknown type jobs")),
}, nil
}
}
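// DropJobsV2 cancels the requested tasks and removes their bookkeeping from the
// node-local task maps; dropping a task that is unknown to this node is a no-op.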
func (i *IndexNode) DropJobsV2(ctx context.Context, req *indexpb.DropJobsV2Request) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(zap.String("clusterID", req.GetClusterID()),
zap.Int64s("taskIDs", req.GetTaskIDs()),
zap.String("jobType", req.GetJobType().String()),
)
if err := i.lifetime.Add(merr.IsHealthyOrStopping); err != nil {
log.Warn("IndexNode not ready", zap.Error(err))
return merr.Status(err), nil
}
defer i.lifetime.Done()
log.Info("IndexNode receive DropJobs request")
switch req.GetJobType() {
case indexpb.JobType_JobTypeIndexJob:
keys := make([]taskKey, 0, len(req.GetTaskIDs()))
for _, buildID := range req.GetTaskIDs() {
keys = append(keys, taskKey{ClusterID: req.GetClusterID(), BuildID: buildID})
}
infos := i.deleteIndexTaskInfos(ctx, keys)
for _, info := range infos {
if info.cancel != nil {
info.cancel()
}
}
log.Info("drop index build jobs success")
return merr.Success(), nil
case indexpb.JobType_JobTypeAnalyzeJob:
keys := make([]taskKey, 0, len(req.GetTaskIDs()))
for _, taskID := range req.GetTaskIDs() {
keys = append(keys, taskKey{ClusterID: req.GetClusterID(), BuildID: taskID})
}
infos := i.deleteAnalyzeTaskInfos(ctx, keys)
for _, info := range infos {
if info.cancel != nil {
info.cancel()
}
}
log.Info("drop analyze jobs success")
return merr.Success(), nil
default:
log.Warn("IndexNode receive dropping unknown type jobs")
return merr.Status(fmt.Errorf("IndexNode receive dropping unknown type jobs")), nil
}
}

View File

@ -21,6 +21,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -100,3 +101,132 @@ func TestMockFieldData(t *testing.T) {
chunkMgr.mockFieldData(100000, 8, 0, 0, 1)
}
type IndexNodeServiceSuite struct {
suite.Suite
cluster string
collectionID int64
partitionID int64
taskID int64
fieldID int64
segmentID int64
}
func (suite *IndexNodeServiceSuite) SetupTest() {
suite.cluster = "test_cluster"
suite.collectionID = 100
suite.partitionID = 102
suite.taskID = 11111
suite.fieldID = 103
suite.segmentID = 104
}
func (suite *IndexNodeServiceSuite) Test_AbnormalIndexNode() {
in, err := NewMockIndexNodeComponent(context.TODO())
suite.NoError(err)
suite.Nil(in.Stop())
ctx := context.TODO()
status, err := in.CreateJob(ctx, &indexpb.CreateJobRequest{})
suite.NoError(err)
suite.ErrorIs(merr.Error(status), merr.ErrServiceNotReady)
qresp, err := in.QueryJobs(ctx, &indexpb.QueryJobsRequest{})
suite.NoError(err)
suite.ErrorIs(merr.Error(qresp.GetStatus()), merr.ErrServiceNotReady)
status, err = in.DropJobs(ctx, &indexpb.DropJobsRequest{})
suite.NoError(err)
suite.ErrorIs(merr.Error(status), merr.ErrServiceNotReady)
jobNumRsp, err := in.GetJobStats(ctx, &indexpb.GetJobStatsRequest{})
suite.NoError(err)
suite.ErrorIs(merr.Error(jobNumRsp.GetStatus()), merr.ErrServiceNotReady)
metricsResp, err := in.GetMetrics(ctx, &milvuspb.GetMetricsRequest{})
err = merr.CheckRPCCall(metricsResp, err)
suite.ErrorIs(err, merr.ErrServiceNotReady)
configurationResp, err := in.ShowConfigurations(ctx, &internalpb.ShowConfigurationsRequest{})
err = merr.CheckRPCCall(configurationResp, err)
suite.ErrorIs(err, merr.ErrServiceNotReady)
status, err = in.CreateJobV2(ctx, &indexpb.CreateJobV2Request{})
err = merr.CheckRPCCall(status, err)
suite.ErrorIs(err, merr.ErrServiceNotReady)
queryAnalyzeResultResp, err := in.QueryJobsV2(ctx, &indexpb.QueryJobsV2Request{})
err = merr.CheckRPCCall(queryAnalyzeResultResp, err)
suite.ErrorIs(err, merr.ErrServiceNotReady)
dropAnalyzeTasksResp, err := in.DropJobsV2(ctx, &indexpb.DropJobsV2Request{})
err = merr.CheckRPCCall(dropAnalyzeTasksResp, err)
suite.ErrorIs(err, merr.ErrServiceNotReady)
}
func (suite *IndexNodeServiceSuite) Test_Method() {
ctx := context.TODO()
in, err := NewMockIndexNodeComponent(context.TODO())
suite.NoError(err)
suite.NoError(in.Stop())
in.UpdateStateCode(commonpb.StateCode_Healthy)
suite.Run("CreateJobV2", func() {
req := &indexpb.AnalyzeRequest{
ClusterID: suite.cluster,
TaskID: suite.taskID,
CollectionID: suite.collectionID,
PartitionID: suite.partitionID,
FieldID: suite.fieldID,
SegmentStats: map[int64]*indexpb.SegmentStats{
suite.segmentID: {
ID: suite.segmentID,
NumRows: 1024,
LogIDs: []int64{1, 2, 3},
},
},
Version: 1,
StorageConfig: nil,
}
resp, err := in.CreateJobV2(ctx, &indexpb.CreateJobV2Request{
ClusterID: suite.cluster,
TaskID: suite.taskID,
JobType: indexpb.JobType_JobTypeAnalyzeJob,
Request: &indexpb.CreateJobV2Request_AnalyzeRequest{
AnalyzeRequest: req,
},
})
err = merr.CheckRPCCall(resp, err)
suite.NoError(err)
})
suite.Run("QueryJobsV2", func() {
req := &indexpb.QueryJobsV2Request{
ClusterID: suite.cluster,
TaskIDs: []int64{suite.taskID},
JobType: indexpb.JobType_JobTypeIndexJob,
}
resp, err := in.QueryJobsV2(ctx, req)
err = merr.CheckRPCCall(resp, err)
suite.NoError(err)
})
suite.Run("DropJobsV2", func() {
req := &indexpb.DropJobsV2Request{
ClusterID: suite.cluster,
TaskIDs: []int64{suite.taskID},
JobType: indexpb.JobType_JobTypeIndexJob,
}
resp, err := in.DropJobsV2(ctx, req)
err = merr.CheckRPCCall(resp, err)
suite.NoError(err)
})
}
func Test_IndexNodeServiceSuite(t *testing.T) {
suite.Run(t, new(IndexNodeServiceSuite))
}

View File

@ -110,17 +110,17 @@ func TestIndexTaskWhenStoppingNode(t *testing.T) {
paramtable.Init()
in := NewIndexNode(ctx, factory)
in.loadOrStoreTask("cluster-1", 1, &taskInfo{
in.loadOrStoreIndexTask("cluster-1", 1, &indexTaskInfo{
state: commonpb.IndexState_InProgress,
})
in.loadOrStoreTask("cluster-2", 2, &taskInfo{
in.loadOrStoreIndexTask("cluster-2", 2, &indexTaskInfo{
state: commonpb.IndexState_Finished,
})
assert.True(t, in.hasInProgressTask())
go func() {
time.Sleep(2 * time.Second)
in.storeTaskState("cluster-1", 1, commonpb.IndexState_Finished, "")
in.storeIndexTaskState("cluster-1", 1, commonpb.IndexState_Finished, "")
}()
noTaskChan := make(chan struct{})
go func() {

View File

@ -19,28 +19,9 @@ package indexnode
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexcgopb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/indexcgowrapper"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
var (
@ -50,550 +31,14 @@ var (
type Blob = storage.Blob
type taskInfo struct {
cancel context.CancelFunc
state commonpb.IndexState
fileKeys []string
serializedSize uint64
failReason string
currentIndexVersion int32
indexStoreVersion int64
}
type task interface {
Ctx() context.Context
Name() string
Prepare(context.Context) error
BuildIndex(context.Context) error
SaveIndexFiles(context.Context) error
OnEnqueue(context.Context) error
SetState(state commonpb.IndexState, failReason string)
GetState() commonpb.IndexState
SetState(state indexpb.JobState, failReason string)
GetState() indexpb.JobState
PreExecute(context.Context) error
Execute(context.Context) error
PostExecute(context.Context) error
Reset()
}
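Because the removed and added method sets are interleaved in the hunk above, the resulting interface is easier to read spelled out. A sketch reconstructed only from the added lines (states move from commonpb.IndexState to indexpb.JobState, and the build-specific stages become the generic PreExecute/Execute/PostExecute):

// Sketch of the task interface after this change, reconstructed from the added lines above.
type task interface {
	Ctx() context.Context
	Name() string
	OnEnqueue(context.Context) error
	SetState(state indexpb.JobState, failReason string)
	GetState() indexpb.JobState
	PreExecute(context.Context) error
	Execute(context.Context) error
	PostExecute(context.Context) error
	Reset()
}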
type indexBuildTaskV2 struct {
*indexBuildTask
}
func newIndexBuildTaskV2(ctx context.Context,
cancel context.CancelFunc,
req *indexpb.CreateJobRequest,
node *IndexNode,
) *indexBuildTaskV2 {
t := &indexBuildTaskV2{
indexBuildTask: &indexBuildTask{
ident: fmt.Sprintf("%s/%d", req.GetClusterID(), req.GetBuildID()),
cancel: cancel,
ctx: ctx,
req: req,
tr: timerecord.NewTimeRecorder(fmt.Sprintf("IndexBuildID: %d, ClusterID: %s", req.GetBuildID(), req.GetClusterID())),
node: node,
},
}
t.parseParams()
return t
}
func (it *indexBuildTaskV2) parseParams() {
// fill field for requests before v2.5.0
if it.req.GetField() == nil || it.req.GetField().GetDataType() == schemapb.DataType_None {
it.req.Field = &schemapb.FieldSchema{
FieldID: it.req.GetFieldID(),
Name: it.req.GetFieldName(),
DataType: it.req.GetFieldType(),
}
}
}
func (it *indexBuildTaskV2) BuildIndex(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()),
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
indexType := it.newIndexParams[common.IndexTypeKey]
if indexType == indexparamcheck.IndexDISKANN {
// check index node support disk index
if !Params.IndexNodeCfg.EnableDisk.GetAsBool() {
log.Warn("IndexNode don't support build disk index",
zap.String("index type", it.newIndexParams[common.IndexTypeKey]),
zap.Bool("enable disk", Params.IndexNodeCfg.EnableDisk.GetAsBool()))
return merr.WrapErrIndexNotSupported("disk index")
}
// check load size and size of field data
localUsedSize, err := indexcgowrapper.GetLocalUsedSize(paramtable.Get().LocalStorageCfg.Path.GetValue())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
}
fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
}
usedLocalSizeWhenBuild := int64(float64(fieldDataSize)*diskUsageRatio) + localUsedSize
maxUsedLocalSize := int64(Params.IndexNodeCfg.DiskCapacityLimit.GetAsFloat() * Params.IndexNodeCfg.MaxDiskUsagePercentage.GetAsFloat())
if usedLocalSizeWhenBuild > maxUsedLocalSize {
log.Warn("IndexNode don't has enough disk size to build disk ann index",
zap.Int64("usedLocalSizeWhenBuild", usedLocalSizeWhenBuild),
zap.Int64("maxUsedLocalSize", maxUsedLocalSize))
return merr.WrapErrServiceDiskLimitExceeded(float32(usedLocalSizeWhenBuild), float32(maxUsedLocalSize))
}
err = indexparams.SetDiskIndexBuildParams(it.newIndexParams, int64(fieldDataSize))
if err != nil {
log.Warn("failed to fill disk index params", zap.Error(err))
return err
}
}
storageConfig := &indexcgopb.StorageConfig{
Address: it.req.GetStorageConfig().GetAddress(),
AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(),
SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(),
UseSSL: it.req.GetStorageConfig().GetUseSSL(),
BucketName: it.req.GetStorageConfig().GetBucketName(),
RootPath: it.req.GetStorageConfig().GetRootPath(),
UseIAM: it.req.GetStorageConfig().GetUseIAM(),
IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(),
StorageType: it.req.GetStorageConfig().GetStorageType(),
UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(),
Region: it.req.GetStorageConfig().GetRegion(),
CloudProvider: it.req.GetStorageConfig().GetCloudProvider(),
RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(),
SslCACert: it.req.GetStorageConfig().GetSslCACert(),
}
optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields()))
for _, optField := range it.req.GetOptionalScalarFields() {
optFields = append(optFields, &indexcgopb.OptionalFieldInfo{
FieldID: optField.GetFieldID(),
FieldName: optField.GetFieldName(),
FieldType: optField.GetFieldType(),
DataPaths: optField.GetDataPaths(),
})
}
buildIndexParams := &indexcgopb.BuildIndexInfo{
ClusterID: it.req.GetClusterID(),
BuildID: it.req.GetBuildID(),
CollectionID: it.req.GetCollectionID(),
PartitionID: it.req.GetPartitionID(),
SegmentID: it.req.GetSegmentID(),
IndexVersion: it.req.GetIndexVersion(),
CurrentIndexVersion: it.req.GetCurrentIndexVersion(),
NumRows: it.req.GetNumRows(),
Dim: it.req.GetDim(),
IndexFilePrefix: it.req.GetIndexFilePrefix(),
InsertFiles: it.req.GetDataPaths(),
FieldSchema: it.req.GetField(),
StorageConfig: storageConfig,
IndexParams: mapToKVPairs(it.newIndexParams),
TypeParams: mapToKVPairs(it.newTypeParams),
StorePath: it.req.GetStorePath(),
StoreVersion: it.req.GetStoreVersion(),
IndexStorePath: it.req.GetIndexStorePath(),
OptFields: optFields,
}
var err error
it.index, err = indexcgowrapper.CreateIndexV2(ctx, buildIndexParams)
if err != nil {
if it.index != nil && it.index.CleanLocalData() != nil {
log.Warn("failed to clean cached data on disk after build index failed")
}
log.Warn("failed to build index", zap.Error(err))
return err
}
buildIndexLatency := it.tr.RecordSpan()
metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(buildIndexLatency.Milliseconds()))
log.Info("Successfully build index")
return nil
}
func (it *indexBuildTaskV2) SaveIndexFiles(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()),
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
gcIndex := func() {
if err := it.index.Delete(); err != nil {
log.Warn("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err))
}
}
version, err := it.index.UpLoadV2()
if err != nil {
log.Warn("failed to upload index", zap.Error(err))
gcIndex()
return err
}
encodeIndexFileDur := it.tr.Record("index serialize and upload done")
metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(encodeIndexFileDur.Seconds())
// early release index for gc, and we can ensure that Delete is idempotent.
gcIndex()
// use serialized size before encoding
var serializedSize uint64
saveFileKeys := make([]string, 0)
it.node.storeIndexFilesAndStatisticV2(it.req.GetClusterID(), it.req.GetBuildID(), saveFileKeys, serializedSize, it.req.GetCurrentIndexVersion(), version)
log.Debug("save index files done", zap.Strings("IndexFiles", saveFileKeys))
saveIndexFileDur := it.tr.RecordSpan()
metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(saveIndexFileDur.Seconds())
it.tr.Elapse("index building all done")
log.Info("Successfully save index files")
return nil
}
// IndexBuildTask is used to record the information of the index tasks.
type indexBuildTask struct {
ident string
cancel context.CancelFunc
ctx context.Context
cm storage.ChunkManager
index indexcgowrapper.CodecIndex
req *indexpb.CreateJobRequest
newTypeParams map[string]string
newIndexParams map[string]string
tr *timerecord.TimeRecorder
queueDur time.Duration
node *IndexNode
}
func newIndexBuildTask(ctx context.Context,
cancel context.CancelFunc,
req *indexpb.CreateJobRequest,
cm storage.ChunkManager,
node *IndexNode,
) *indexBuildTask {
t := &indexBuildTask{
ident: fmt.Sprintf("%s/%d", req.GetClusterID(), req.GetBuildID()),
cancel: cancel,
ctx: ctx,
cm: cm,
req: req,
tr: timerecord.NewTimeRecorder(fmt.Sprintf("IndexBuildID: %d, ClusterID: %s", req.GetBuildID(), req.GetClusterID())),
node: node,
}
t.parseParams()
return t
}
func (it *indexBuildTask) parseParams() {
// fill field for requests before v2.5.0
if it.req.GetField() == nil || it.req.GetField().GetDataType() == schemapb.DataType_None {
it.req.Field = &schemapb.FieldSchema{
FieldID: it.req.GetFieldID(),
Name: it.req.GetFieldName(),
DataType: it.req.GetFieldType(),
}
}
}
func (it *indexBuildTask) Reset() {
it.ident = ""
it.cancel = nil
it.ctx = nil
it.cm = nil
it.index = nil
it.req = nil
it.newTypeParams = nil
it.newIndexParams = nil
it.tr = nil
it.node = nil
}
// Ctx is the context of index tasks.
func (it *indexBuildTask) Ctx() context.Context {
return it.ctx
}
// Name is the name of task to build index.
func (it *indexBuildTask) Name() string {
return it.ident
}
func (it *indexBuildTask) SetState(state commonpb.IndexState, failReason string) {
it.node.storeTaskState(it.req.GetClusterID(), it.req.GetBuildID(), state, failReason)
}
func (it *indexBuildTask) GetState() commonpb.IndexState {
return it.node.loadTaskState(it.req.GetClusterID(), it.req.GetBuildID())
}
// OnEnqueue enqueues indexing tasks.
func (it *indexBuildTask) OnEnqueue(ctx context.Context) error {
it.queueDur = 0
it.tr.RecordSpan()
log.Ctx(ctx).Info("IndexNode IndexBuilderTask Enqueue", zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("segmentID", it.req.GetSegmentID()))
return nil
}
func (it *indexBuildTask) Prepare(ctx context.Context) error {
it.queueDur = it.tr.RecordSpan()
log.Ctx(ctx).Info("Begin to prepare indexBuildTask", zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("Collection", it.req.GetCollectionID()), zap.Int64("SegmentID", it.req.GetSegmentID()))
typeParams := make(map[string]string)
indexParams := make(map[string]string)
if len(it.req.DataPaths) == 0 {
for _, id := range it.req.GetDataIds() {
path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), it.req.GetField().GetFieldID(), id)
it.req.DataPaths = append(it.req.DataPaths, path)
}
}
if it.req.OptionalScalarFields != nil {
for _, optFields := range it.req.GetOptionalScalarFields() {
if len(optFields.DataPaths) == 0 {
for _, id := range optFields.DataIds {
path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), optFields.FieldID, id)
optFields.DataPaths = append(optFields.DataPaths, path)
}
}
}
}
// type params can be removed
for _, kvPair := range it.req.GetTypeParams() {
key, value := kvPair.GetKey(), kvPair.GetValue()
typeParams[key] = value
indexParams[key] = value
}
for _, kvPair := range it.req.GetIndexParams() {
key, value := kvPair.GetKey(), kvPair.GetValue()
// knowhere would report error if encountered the unknown key,
// so skip this
if key == common.MmapEnabledKey {
continue
}
indexParams[key] = value
}
it.newTypeParams = typeParams
it.newIndexParams = indexParams
if it.req.GetDim() == 0 {
// fill dim for requests before v2.4.0
if dimStr, ok := typeParams[common.DimKey]; ok {
var err error
it.req.Dim, err = strconv.ParseInt(dimStr, 10, 64)
if err != nil {
log.Ctx(ctx).Error("parse dimesion failed", zap.Error(err))
// ignore error
}
}
}
if it.req.GetCollectionID() == 0 {
err := it.parseFieldMetaFromBinlog(ctx)
if err != nil {
log.Ctx(ctx).Warn("parse field meta from binlog failed", zap.Error(err))
return err
}
}
log.Ctx(ctx).Info("Successfully prepare indexBuildTask", zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collectionID", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()))
return nil
}
func (it *indexBuildTask) BuildIndex(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()),
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
indexType := it.newIndexParams[common.IndexTypeKey]
if indexType == indexparamcheck.IndexDISKANN {
// check index node support disk index
if !Params.IndexNodeCfg.EnableDisk.GetAsBool() {
log.Warn("IndexNode don't support build disk index",
zap.String("index type", it.newIndexParams[common.IndexTypeKey]),
zap.Bool("enable disk", Params.IndexNodeCfg.EnableDisk.GetAsBool()))
return errors.New("index node don't support build disk index")
}
// check load size and size of field data
localUsedSize, err := indexcgowrapper.GetLocalUsedSize(paramtable.Get().LocalStorageCfg.Path.GetValue())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
}
fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
}
usedLocalSizeWhenBuild := int64(float64(fieldDataSize)*diskUsageRatio) + localUsedSize
maxUsedLocalSize := int64(Params.IndexNodeCfg.DiskCapacityLimit.GetAsFloat() * Params.IndexNodeCfg.MaxDiskUsagePercentage.GetAsFloat())
if usedLocalSizeWhenBuild > maxUsedLocalSize {
log.Warn("IndexNode don't has enough disk size to build disk ann index",
zap.Int64("usedLocalSizeWhenBuild", usedLocalSizeWhenBuild),
zap.Int64("maxUsedLocalSize", maxUsedLocalSize))
return errors.New("index node don't has enough disk size to build disk ann index")
}
err = indexparams.SetDiskIndexBuildParams(it.newIndexParams, int64(fieldDataSize))
if err != nil {
log.Warn("failed to fill disk index params", zap.Error(err))
return err
}
}
storageConfig := &indexcgopb.StorageConfig{
Address: it.req.GetStorageConfig().GetAddress(),
AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(),
SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(),
UseSSL: it.req.GetStorageConfig().GetUseSSL(),
BucketName: it.req.GetStorageConfig().GetBucketName(),
RootPath: it.req.GetStorageConfig().GetRootPath(),
UseIAM: it.req.GetStorageConfig().GetUseIAM(),
IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(),
StorageType: it.req.GetStorageConfig().GetStorageType(),
UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(),
Region: it.req.GetStorageConfig().GetRegion(),
CloudProvider: it.req.GetStorageConfig().GetCloudProvider(),
RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(),
SslCACert: it.req.GetStorageConfig().GetSslCACert(),
}
optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields()))
for _, optField := range it.req.GetOptionalScalarFields() {
optFields = append(optFields, &indexcgopb.OptionalFieldInfo{
FieldID: optField.GetFieldID(),
FieldName: optField.GetFieldName(),
FieldType: optField.GetFieldType(),
DataPaths: optField.GetDataPaths(),
})
}
buildIndexParams := &indexcgopb.BuildIndexInfo{
ClusterID: it.req.GetClusterID(),
BuildID: it.req.GetBuildID(),
CollectionID: it.req.GetCollectionID(),
PartitionID: it.req.GetPartitionID(),
SegmentID: it.req.GetSegmentID(),
IndexVersion: it.req.GetIndexVersion(),
CurrentIndexVersion: it.req.GetCurrentIndexVersion(),
NumRows: it.req.GetNumRows(),
Dim: it.req.GetDim(),
IndexFilePrefix: it.req.GetIndexFilePrefix(),
InsertFiles: it.req.GetDataPaths(),
FieldSchema: it.req.GetField(),
StorageConfig: storageConfig,
IndexParams: mapToKVPairs(it.newIndexParams),
TypeParams: mapToKVPairs(it.newTypeParams),
StorePath: it.req.GetStorePath(),
StoreVersion: it.req.GetStoreVersion(),
IndexStorePath: it.req.GetIndexStorePath(),
OptFields: optFields,
}
log.Info("debug create index", zap.Any("buildIndexParams", buildIndexParams))
var err error
it.index, err = indexcgowrapper.CreateIndex(ctx, buildIndexParams)
if err != nil {
if it.index != nil && it.index.CleanLocalData() != nil {
log.Warn("failed to clean cached data on disk after build index failed")
}
log.Warn("failed to build index", zap.Error(err))
return err
}
buildIndexLatency := it.tr.RecordSpan()
metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(buildIndexLatency.Seconds())
log.Info("Successfully build index")
return nil
}
func (it *indexBuildTask) SaveIndexFiles(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()),
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
gcIndex := func() {
if err := it.index.Delete(); err != nil {
log.Warn("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err))
}
}
indexFilePath2Size, err := it.index.UpLoad()
if err != nil {
log.Warn("failed to upload index", zap.Error(err))
gcIndex()
return err
}
encodeIndexFileDur := it.tr.Record("index serialize and upload done")
metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(encodeIndexFileDur.Seconds())
// early release index for gc, and we can ensure that Delete is idempotent.
gcIndex()
// use serialized size before encoding
var serializedSize uint64
saveFileKeys := make([]string, 0)
for filePath, fileSize := range indexFilePath2Size {
serializedSize += uint64(fileSize)
parts := strings.Split(filePath, "/")
fileKey := parts[len(parts)-1]
saveFileKeys = append(saveFileKeys, fileKey)
}
it.node.storeIndexFilesAndStatistic(it.req.GetClusterID(), it.req.GetBuildID(), saveFileKeys, serializedSize, it.req.GetCurrentIndexVersion())
log.Debug("save index files done", zap.Strings("IndexFiles", saveFileKeys))
saveIndexFileDur := it.tr.RecordSpan()
metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(saveIndexFileDur.Seconds())
it.tr.Elapse("index building all done")
log.Info("Successfully save index files")
return nil
}
func (it *indexBuildTask) parseFieldMetaFromBinlog(ctx context.Context) error {
// fill collectionID, partitionID... for requests before v2.4.0
toLoadDataPaths := it.req.GetDataPaths()
if len(toLoadDataPaths) == 0 {
return merr.WrapErrParameterInvalidMsg("data insert path must be not empty")
}
data, err := it.cm.Read(ctx, toLoadDataPaths[0])
if err != nil {
if errors.Is(err, merr.ErrIoKeyNotFound) {
return err
}
return err
}
var insertCodec storage.InsertCodec
collectionID, partitionID, segmentID, insertData, err := insertCodec.DeserializeAll([]*Blob{{Key: toLoadDataPaths[0], Value: data}})
if err != nil {
return err
}
if len(insertData.Data) != 1 {
return merr.WrapErrParameterInvalidMsg("we expect only one field in deserialized insert data")
}
it.req.CollectionID = collectionID
it.req.PartitionID = partitionID
it.req.SegmentID = segmentID
if it.req.GetField().GetFieldID() == 0 {
for fID, value := range insertData.Data {
it.req.Field.DataType = value.GetDataType()
it.req.Field.FieldID = fID
break
}
}
it.req.CurrentIndexVersion = getCurrentIndexVersion(it.req.GetCurrentIndexVersion())
return nil
}

View File

@ -0,0 +1,215 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package indexnode
import (
"context"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/clusteringpb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/util/analyzecgowrapper"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/hardware"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
type analyzeTask struct {
ident string
ctx context.Context
cancel context.CancelFunc
req *indexpb.AnalyzeRequest
tr *timerecord.TimeRecorder
queueDur time.Duration
node *IndexNode
analyze analyzecgowrapper.CodecAnalyze
startTime int64
endTime int64
}
func (at *analyzeTask) Ctx() context.Context {
return at.ctx
}
func (at *analyzeTask) Name() string {
return at.ident
}
func (at *analyzeTask) PreExecute(ctx context.Context) error {
at.queueDur = at.tr.RecordSpan()
log := log.Ctx(ctx).With(zap.String("clusterID", at.req.GetClusterID()),
zap.Int64("taskID", at.req.GetTaskID()), zap.Int64("Collection", at.req.GetCollectionID()),
zap.Int64("partitionID", at.req.GetPartitionID()), zap.Int64("fieldID", at.req.GetFieldID()))
log.Info("Begin to prepare analyze task")
log.Info("Successfully prepare analyze task, nothing to do...")
return nil
}
func (at *analyzeTask) Execute(ctx context.Context) error {
var err error
log := log.Ctx(ctx).With(zap.String("clusterID", at.req.GetClusterID()),
zap.Int64("taskID", at.req.GetTaskID()), zap.Int64("Collection", at.req.GetCollectionID()),
zap.Int64("partitionID", at.req.GetPartitionID()), zap.Int64("fieldID", at.req.GetFieldID()))
log.Info("Begin to build analyze task")
if err != nil {
log.Warn("create analyze info failed", zap.Error(err))
return err
}
storageConfig := &clusteringpb.StorageConfig{
Address: at.req.GetStorageConfig().GetAddress(),
AccessKeyID: at.req.GetStorageConfig().GetAccessKeyID(),
SecretAccessKey: at.req.GetStorageConfig().GetSecretAccessKey(),
UseSSL: at.req.GetStorageConfig().GetUseSSL(),
BucketName: at.req.GetStorageConfig().GetBucketName(),
RootPath: at.req.GetStorageConfig().GetRootPath(),
UseIAM: at.req.GetStorageConfig().GetUseIAM(),
IAMEndpoint: at.req.GetStorageConfig().GetIAMEndpoint(),
StorageType: at.req.GetStorageConfig().GetStorageType(),
UseVirtualHost: at.req.GetStorageConfig().GetUseVirtualHost(),
Region: at.req.GetStorageConfig().GetRegion(),
CloudProvider: at.req.GetStorageConfig().GetCloudProvider(),
RequestTimeoutMs: at.req.GetStorageConfig().GetRequestTimeoutMs(),
SslCACert: at.req.GetStorageConfig().GetSslCACert(),
}
numRowsMap := make(map[int64]int64)
segmentInsertFilesMap := make(map[int64]*clusteringpb.InsertFiles)
for segID, stats := range at.req.GetSegmentStats() {
numRows := stats.GetNumRows()
numRowsMap[segID] = numRows
log.Info("append segment rows", zap.Int64("segment id", segID), zap.Int64("rows", numRows))
if err != nil {
log.Warn("append segment num rows failed", zap.Error(err))
return err
}
insertFiles := make([]string, 0, len(stats.GetLogIDs()))
for _, id := range stats.GetLogIDs() {
path := metautil.BuildInsertLogPath(at.req.GetStorageConfig().RootPath,
at.req.GetCollectionID(), at.req.GetPartitionID(), segID, at.req.GetFieldID(), id)
insertFiles = append(insertFiles, path)
if err != nil {
log.Warn("append insert binlog path failed", zap.Error(err))
return err
}
}
segmentInsertFilesMap[segID] = &clusteringpb.InsertFiles{InsertFiles: insertFiles}
}
field := at.req.GetField()
if field == nil || field.GetDataType() == schemapb.DataType_None {
field = &schemapb.FieldSchema{
FieldID: at.req.GetFieldID(),
Name: at.req.GetFieldName(),
DataType: at.req.GetFieldType(),
}
}
analyzeInfo := &clusteringpb.AnalyzeInfo{
ClusterID: at.req.GetClusterID(),
BuildID: at.req.GetTaskID(),
CollectionID: at.req.GetCollectionID(),
PartitionID: at.req.GetPartitionID(),
Version: at.req.GetVersion(),
Dim: at.req.GetDim(),
StorageConfig: storageConfig,
NumClusters: at.req.GetNumClusters(),
TrainSize: int64(float64(hardware.GetMemoryCount()) * at.req.GetMaxTrainSizeRatio()),
MinClusterRatio: at.req.GetMinClusterSizeRatio(),
MaxClusterRatio: at.req.GetMaxClusterSizeRatio(),
MaxClusterSize: at.req.GetMaxClusterSize(),
NumRows: numRowsMap,
InsertFiles: segmentInsertFilesMap,
FieldSchema: field,
}
at.analyze, err = analyzecgowrapper.Analyze(ctx, analyzeInfo)
if err != nil {
log.Error("failed to analyze data", zap.Error(err))
return err
}
analyzeLatency := at.tr.RecordSpan()
log.Info("analyze done", zap.Int64("analyze cost", analyzeLatency.Milliseconds()))
return nil
}
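For a sense of scale on the TrainSize field set in Execute above: it is the node's physical memory scaled by the request's max train-size ratio, so training never sees more data than that budget. A small illustrative calculation, assuming hardware.GetMemoryCount reports bytes and using a hypothetical 64 GiB node with a 0.8 ratio:

// Illustrative only: the 64 GiB figure and the 0.8 ratio are assumptions, not values from this commit.
func exampleTrainSizeBudget() int64 {
	memoryBytes := uint64(64) << 30 // what hardware.GetMemoryCount() might report on a 64 GiB node
	maxTrainSizeRatio := 0.8        // AnalyzeRequest.GetMaxTrainSizeRatio()
	return int64(float64(memoryBytes) * maxTrainSizeRatio) // 54975581388 bytes, roughly 51.2 GiB
}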
func (at *analyzeTask) PostExecute(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.String("clusterID", at.req.GetClusterID()),
zap.Int64("taskID", at.req.GetTaskID()), zap.Int64("Collection", at.req.GetCollectionID()),
zap.Int64("partitionID", at.req.GetPartitionID()), zap.Int64("fieldID", at.req.GetFieldID()))
gc := func() {
if err := at.analyze.Delete(); err != nil {
log.Error("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err))
}
}
defer gc()
centroidsFile, _, _, _, err := at.analyze.GetResult(len(at.req.GetSegmentStats()))
if err != nil {
log.Error("failed to upload index", zap.Error(err))
return err
}
log.Info("analyze result", zap.String("centroidsFile", centroidsFile))
at.endTime = time.Now().UnixMicro()
at.node.storeAnalyzeFilesAndStatistic(at.req.GetClusterID(),
at.req.GetTaskID(),
centroidsFile)
at.tr.Elapse("index building all done")
log.Info("Successfully save analyze files")
return nil
}
func (at *analyzeTask) OnEnqueue(ctx context.Context) error {
at.queueDur = 0
at.tr.RecordSpan()
at.startTime = time.Now().UnixMicro()
log.Ctx(ctx).Info("IndexNode analyzeTask enqueued", zap.String("clusterID", at.req.GetClusterID()),
zap.Int64("taskID", at.req.GetTaskID()))
return nil
}
func (at *analyzeTask) SetState(state indexpb.JobState, failReason string) {
at.node.storeAnalyzeTaskState(at.req.GetClusterID(), at.req.GetTaskID(), state, failReason)
}
func (at *analyzeTask) GetState() indexpb.JobState {
return at.node.loadAnalyzeTaskState(at.req.GetClusterID(), at.req.GetTaskID())
}
func (at *analyzeTask) Reset() {
at.ident = ""
at.ctx = nil
at.cancel = nil
at.req = nil
at.tr = nil
at.queueDur = 0
at.node = nil
at.startTime = 0
at.endTime = 0
}
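Putting the lifecycle together: the scheduler changes later in this commit drive an analyze task through the same three-stage pipeline as index builds. A condensed, illustrative sketch of that control flow (the error-to-state mapping comes from getStateFromError in the scheduler file further below), not a literal copy of processTask:

// Condensed from TaskScheduler.processTask further below; illustrative only.
func runAnalyzeTask(ctx context.Context, t *analyzeTask) {
	_ = t.OnEnqueue(ctx) // records enqueue time and start timestamp
	stages := []func(context.Context) error{t.PreExecute, t.Execute, t.PostExecute}
	for _, stage := range stages {
		if err := stage(ctx); err != nil {
			t.SetState(getStateFromError(err), err.Error())
			return
		}
	}
	t.SetState(indexpb.JobState_JobStateFinished, "")
}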

View File

@ -0,0 +1,570 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package indexnode
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexcgopb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/indexcgowrapper"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
type indexBuildTaskV2 struct {
*indexBuildTask
}
func newIndexBuildTaskV2(ctx context.Context,
cancel context.CancelFunc,
req *indexpb.CreateJobRequest,
node *IndexNode,
) *indexBuildTaskV2 {
t := &indexBuildTaskV2{
indexBuildTask: &indexBuildTask{
ident: fmt.Sprintf("%s/%d", req.GetClusterID(), req.GetBuildID()),
cancel: cancel,
ctx: ctx,
req: req,
tr: timerecord.NewTimeRecorder(fmt.Sprintf("IndexBuildID: %d, ClusterID: %s", req.GetBuildID(), req.GetClusterID())),
node: node,
},
}
t.parseParams()
return t
}
func (it *indexBuildTaskV2) parseParams() {
// fill field for requests before v2.5.0
if it.req.GetField() == nil || it.req.GetField().GetDataType() == schemapb.DataType_None {
it.req.Field = &schemapb.FieldSchema{
FieldID: it.req.GetFieldID(),
Name: it.req.GetFieldName(),
DataType: it.req.GetFieldType(),
}
}
}
func (it *indexBuildTaskV2) Execute(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()),
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
indexType := it.newIndexParams[common.IndexTypeKey]
if indexType == indexparamcheck.IndexDISKANN {
// check index node support disk index
if !Params.IndexNodeCfg.EnableDisk.GetAsBool() {
log.Warn("IndexNode don't support build disk index",
zap.String("index type", it.newIndexParams[common.IndexTypeKey]),
zap.Bool("enable disk", Params.IndexNodeCfg.EnableDisk.GetAsBool()))
return merr.WrapErrIndexNotSupported("disk index")
}
// check load size and size of field data
localUsedSize, err := indexcgowrapper.GetLocalUsedSize(paramtable.Get().LocalStorageCfg.Path.GetValue())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
}
fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
}
usedLocalSizeWhenBuild := int64(float64(fieldDataSize)*diskUsageRatio) + localUsedSize
maxUsedLocalSize := int64(Params.IndexNodeCfg.DiskCapacityLimit.GetAsFloat() * Params.IndexNodeCfg.MaxDiskUsagePercentage.GetAsFloat())
if usedLocalSizeWhenBuild > maxUsedLocalSize {
log.Warn("IndexNode don't has enough disk size to build disk ann index",
zap.Int64("usedLocalSizeWhenBuild", usedLocalSizeWhenBuild),
zap.Int64("maxUsedLocalSize", maxUsedLocalSize))
return merr.WrapErrServiceDiskLimitExceeded(float32(usedLocalSizeWhenBuild), float32(maxUsedLocalSize))
}
err = indexparams.SetDiskIndexBuildParams(it.newIndexParams, int64(fieldDataSize))
if err != nil {
log.Warn("failed to fill disk index params", zap.Error(err))
return err
}
}
storageConfig := &indexcgopb.StorageConfig{
Address: it.req.GetStorageConfig().GetAddress(),
AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(),
SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(),
UseSSL: it.req.GetStorageConfig().GetUseSSL(),
BucketName: it.req.GetStorageConfig().GetBucketName(),
RootPath: it.req.GetStorageConfig().GetRootPath(),
UseIAM: it.req.GetStorageConfig().GetUseIAM(),
IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(),
StorageType: it.req.GetStorageConfig().GetStorageType(),
UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(),
Region: it.req.GetStorageConfig().GetRegion(),
CloudProvider: it.req.GetStorageConfig().GetCloudProvider(),
RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(),
SslCACert: it.req.GetStorageConfig().GetSslCACert(),
}
optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields()))
for _, optField := range it.req.GetOptionalScalarFields() {
optFields = append(optFields, &indexcgopb.OptionalFieldInfo{
FieldID: optField.GetFieldID(),
FieldName: optField.GetFieldName(),
FieldType: optField.GetFieldType(),
DataPaths: optField.GetDataPaths(),
})
}
buildIndexParams := &indexcgopb.BuildIndexInfo{
ClusterID: it.req.GetClusterID(),
BuildID: it.req.GetBuildID(),
CollectionID: it.req.GetCollectionID(),
PartitionID: it.req.GetPartitionID(),
SegmentID: it.req.GetSegmentID(),
IndexVersion: it.req.GetIndexVersion(),
CurrentIndexVersion: it.req.GetCurrentIndexVersion(),
NumRows: it.req.GetNumRows(),
Dim: it.req.GetDim(),
IndexFilePrefix: it.req.GetIndexFilePrefix(),
InsertFiles: it.req.GetDataPaths(),
FieldSchema: it.req.GetField(),
StorageConfig: storageConfig,
IndexParams: mapToKVPairs(it.newIndexParams),
TypeParams: mapToKVPairs(it.newTypeParams),
StorePath: it.req.GetStorePath(),
StoreVersion: it.req.GetStoreVersion(),
IndexStorePath: it.req.GetIndexStorePath(),
OptFields: optFields,
}
var err error
it.index, err = indexcgowrapper.CreateIndexV2(ctx, buildIndexParams)
if err != nil {
if it.index != nil && it.index.CleanLocalData() != nil {
log.Warn("failed to clean cached data on disk after build index failed")
}
log.Warn("failed to build index", zap.Error(err))
return err
}
buildIndexLatency := it.tr.RecordSpan()
metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(buildIndexLatency.Milliseconds()))
log.Info("Successfully build index")
return nil
}
func (it *indexBuildTaskV2) PostExecute(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()),
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
gcIndex := func() {
if err := it.index.Delete(); err != nil {
log.Warn("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err))
}
}
version, err := it.index.UpLoadV2()
if err != nil {
log.Warn("failed to upload index", zap.Error(err))
gcIndex()
return err
}
encodeIndexFileDur := it.tr.Record("index serialize and upload done")
metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(encodeIndexFileDur.Seconds())
// early release index for gc, and we can ensure that Delete is idempotent.
gcIndex()
// use serialized size before encoding
var serializedSize uint64
saveFileKeys := make([]string, 0)
it.node.storeIndexFilesAndStatisticV2(it.req.GetClusterID(), it.req.GetBuildID(), saveFileKeys, serializedSize, it.req.GetCurrentIndexVersion(), version)
log.Debug("save index files done", zap.Strings("IndexFiles", saveFileKeys))
saveIndexFileDur := it.tr.RecordSpan()
metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(saveIndexFileDur.Seconds())
it.tr.Elapse("index building all done")
log.Info("Successfully save index files")
return nil
}
// IndexBuildTask is used to record the information of the index tasks.
type indexBuildTask struct {
ident string
cancel context.CancelFunc
ctx context.Context
cm storage.ChunkManager
index indexcgowrapper.CodecIndex
req *indexpb.CreateJobRequest
newTypeParams map[string]string
newIndexParams map[string]string
tr *timerecord.TimeRecorder
queueDur time.Duration
node *IndexNode
}
func newIndexBuildTask(ctx context.Context,
cancel context.CancelFunc,
req *indexpb.CreateJobRequest,
cm storage.ChunkManager,
node *IndexNode,
) *indexBuildTask {
t := &indexBuildTask{
ident: fmt.Sprintf("%s/%d", req.GetClusterID(), req.GetBuildID()),
cancel: cancel,
ctx: ctx,
cm: cm,
req: req,
tr: timerecord.NewTimeRecorder(fmt.Sprintf("IndexBuildID: %d, ClusterID: %s", req.GetBuildID(), req.GetClusterID())),
node: node,
}
t.parseParams()
return t
}
func (it *indexBuildTask) parseParams() {
// fill field for requests before v2.5.0
if it.req.GetField() == nil || it.req.GetField().GetDataType() == schemapb.DataType_None {
it.req.Field = &schemapb.FieldSchema{
FieldID: it.req.GetFieldID(),
Name: it.req.GetFieldName(),
DataType: it.req.GetFieldType(),
}
}
}
func (it *indexBuildTask) Reset() {
it.ident = ""
it.cancel = nil
it.ctx = nil
it.cm = nil
it.index = nil
it.req = nil
it.newTypeParams = nil
it.newIndexParams = nil
it.tr = nil
it.node = nil
}
// Ctx is the context of index tasks.
func (it *indexBuildTask) Ctx() context.Context {
return it.ctx
}
// Name is the name of task to build index.
func (it *indexBuildTask) Name() string {
return it.ident
}
func (it *indexBuildTask) SetState(state indexpb.JobState, failReason string) {
it.node.storeIndexTaskState(it.req.GetClusterID(), it.req.GetBuildID(), commonpb.IndexState(state), failReason)
}
func (it *indexBuildTask) GetState() indexpb.JobState {
return indexpb.JobState(it.node.loadIndexTaskState(it.req.GetClusterID(), it.req.GetBuildID()))
}
// OnEnqueue enqueues indexing tasks.
func (it *indexBuildTask) OnEnqueue(ctx context.Context) error {
it.queueDur = 0
it.tr.RecordSpan()
log.Ctx(ctx).Info("IndexNode IndexBuilderTask Enqueue", zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("segmentID", it.req.GetSegmentID()))
return nil
}
func (it *indexBuildTask) PreExecute(ctx context.Context) error {
it.queueDur = it.tr.RecordSpan()
log.Ctx(ctx).Info("Begin to prepare indexBuildTask", zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("Collection", it.req.GetCollectionID()), zap.Int64("SegmentID", it.req.GetSegmentID()))
typeParams := make(map[string]string)
indexParams := make(map[string]string)
if len(it.req.DataPaths) == 0 {
for _, id := range it.req.GetDataIds() {
path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), it.req.GetField().GetFieldID(), id)
it.req.DataPaths = append(it.req.DataPaths, path)
}
}
if it.req.OptionalScalarFields != nil {
for _, optFields := range it.req.GetOptionalScalarFields() {
if len(optFields.DataPaths) == 0 {
for _, id := range optFields.DataIds {
path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), optFields.FieldID, id)
optFields.DataPaths = append(optFields.DataPaths, path)
}
}
}
}
// type params can be removed
for _, kvPair := range it.req.GetTypeParams() {
key, value := kvPair.GetKey(), kvPair.GetValue()
typeParams[key] = value
indexParams[key] = value
}
for _, kvPair := range it.req.GetIndexParams() {
key, value := kvPair.GetKey(), kvPair.GetValue()
// knowhere would report an error if it encountered an unknown key,
// so skip it
if key == common.MmapEnabledKey {
continue
}
indexParams[key] = value
}
it.newTypeParams = typeParams
it.newIndexParams = indexParams
if it.req.GetDim() == 0 {
// fill dim for requests before v2.4.0
if dimStr, ok := typeParams[common.DimKey]; ok {
var err error
it.req.Dim, err = strconv.ParseInt(dimStr, 10, 64)
if err != nil {
log.Ctx(ctx).Error("parse dimesion failed", zap.Error(err))
// ignore error
}
}
}
if it.req.GetCollectionID() == 0 {
err := it.parseFieldMetaFromBinlog(ctx)
if err != nil {
log.Ctx(ctx).Warn("parse field meta from binlog failed", zap.Error(err))
return err
}
}
log.Ctx(ctx).Info("Successfully prepare indexBuildTask", zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collectionID", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()))
return nil
}
func (it *indexBuildTask) Execute(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()),
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
indexType := it.newIndexParams[common.IndexTypeKey]
if indexType == indexparamcheck.IndexDISKANN {
// check index node support disk index
if !Params.IndexNodeCfg.EnableDisk.GetAsBool() {
log.Warn("IndexNode don't support build disk index",
zap.String("index type", it.newIndexParams[common.IndexTypeKey]),
zap.Bool("enable disk", Params.IndexNodeCfg.EnableDisk.GetAsBool()))
return errors.New("index node don't support build disk index")
}
// check load size and size of field data
localUsedSize, err := indexcgowrapper.GetLocalUsedSize(paramtable.Get().LocalStorageCfg.Path.GetValue())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
}
fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
}
usedLocalSizeWhenBuild := int64(float64(fieldDataSize)*diskUsageRatio) + localUsedSize
maxUsedLocalSize := int64(Params.IndexNodeCfg.DiskCapacityLimit.GetAsFloat() * Params.IndexNodeCfg.MaxDiskUsagePercentage.GetAsFloat())
if usedLocalSizeWhenBuild > maxUsedLocalSize {
log.Warn("IndexNode don't has enough disk size to build disk ann index",
zap.Int64("usedLocalSizeWhenBuild", usedLocalSizeWhenBuild),
zap.Int64("maxUsedLocalSize", maxUsedLocalSize))
return errors.New("index node don't has enough disk size to build disk ann index")
}
err = indexparams.SetDiskIndexBuildParams(it.newIndexParams, int64(fieldDataSize))
if err != nil {
log.Warn("failed to fill disk index params", zap.Error(err))
return err
}
}
storageConfig := &indexcgopb.StorageConfig{
Address: it.req.GetStorageConfig().GetAddress(),
AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(),
SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(),
UseSSL: it.req.GetStorageConfig().GetUseSSL(),
BucketName: it.req.GetStorageConfig().GetBucketName(),
RootPath: it.req.GetStorageConfig().GetRootPath(),
UseIAM: it.req.GetStorageConfig().GetUseIAM(),
IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(),
StorageType: it.req.GetStorageConfig().GetStorageType(),
UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(),
Region: it.req.GetStorageConfig().GetRegion(),
CloudProvider: it.req.GetStorageConfig().GetCloudProvider(),
RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(),
SslCACert: it.req.GetStorageConfig().GetSslCACert(),
}
optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields()))
for _, optField := range it.req.GetOptionalScalarFields() {
optFields = append(optFields, &indexcgopb.OptionalFieldInfo{
FieldID: optField.GetFieldID(),
FieldName: optField.GetFieldName(),
FieldType: optField.GetFieldType(),
DataPaths: optField.GetDataPaths(),
})
}
buildIndexParams := &indexcgopb.BuildIndexInfo{
ClusterID: it.req.GetClusterID(),
BuildID: it.req.GetBuildID(),
CollectionID: it.req.GetCollectionID(),
PartitionID: it.req.GetPartitionID(),
SegmentID: it.req.GetSegmentID(),
IndexVersion: it.req.GetIndexVersion(),
CurrentIndexVersion: it.req.GetCurrentIndexVersion(),
NumRows: it.req.GetNumRows(),
Dim: it.req.GetDim(),
IndexFilePrefix: it.req.GetIndexFilePrefix(),
InsertFiles: it.req.GetDataPaths(),
FieldSchema: it.req.GetField(),
StorageConfig: storageConfig,
IndexParams: mapToKVPairs(it.newIndexParams),
TypeParams: mapToKVPairs(it.newTypeParams),
StorePath: it.req.GetStorePath(),
StoreVersion: it.req.GetStoreVersion(),
IndexStorePath: it.req.GetIndexStorePath(),
OptFields: optFields,
}
log.Info("debug create index", zap.Any("buildIndexParams", buildIndexParams))
var err error
it.index, err = indexcgowrapper.CreateIndex(ctx, buildIndexParams)
if err != nil {
if it.index != nil && it.index.CleanLocalData() != nil {
log.Warn("failed to clean cached data on disk after build index failed")
}
log.Warn("failed to build index", zap.Error(err))
return err
}
buildIndexLatency := it.tr.RecordSpan()
metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(buildIndexLatency.Seconds())
log.Info("Successfully build index")
return nil
}
func (it *indexBuildTask) PostExecute(ctx context.Context) error {
log := log.Ctx(ctx).With(zap.String("clusterID", it.req.GetClusterID()), zap.Int64("buildID", it.req.GetBuildID()),
zap.Int64("collection", it.req.GetCollectionID()), zap.Int64("segmentID", it.req.GetSegmentID()),
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
gcIndex := func() {
if err := it.index.Delete(); err != nil {
log.Warn("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err))
}
}
indexFilePath2Size, err := it.index.UpLoad()
if err != nil {
log.Warn("failed to upload index", zap.Error(err))
gcIndex()
return err
}
encodeIndexFileDur := it.tr.Record("index serialize and upload done")
metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(encodeIndexFileDur.Seconds())
// early release index for gc, and we can ensure that Delete is idempotent.
gcIndex()
// use serialized size before encoding
var serializedSize uint64
saveFileKeys := make([]string, 0)
for filePath, fileSize := range indexFilePath2Size {
serializedSize += uint64(fileSize)
parts := strings.Split(filePath, "/")
fileKey := parts[len(parts)-1]
saveFileKeys = append(saveFileKeys, fileKey)
}
it.node.storeIndexFilesAndStatistic(it.req.GetClusterID(), it.req.GetBuildID(), saveFileKeys, serializedSize, it.req.GetCurrentIndexVersion())
log.Debug("save index files done", zap.Strings("IndexFiles", saveFileKeys))
saveIndexFileDur := it.tr.RecordSpan()
metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(saveIndexFileDur.Seconds())
it.tr.Elapse("index building all done")
log.Info("Successfully save index files")
return nil
}
func (it *indexBuildTask) parseFieldMetaFromBinlog(ctx context.Context) error {
// fill collectionID, partitionID... for requests before v2.4.0
toLoadDataPaths := it.req.GetDataPaths()
if len(toLoadDataPaths) == 0 {
return merr.WrapErrParameterInvalidMsg("data insert path must be not empty")
}
data, err := it.cm.Read(ctx, toLoadDataPaths[0])
if err != nil {
if errors.Is(err, merr.ErrIoKeyNotFound) {
return err
}
return err
}
var insertCodec storage.InsertCodec
collectionID, partitionID, segmentID, insertData, err := insertCodec.DeserializeAll([]*Blob{{Key: toLoadDataPaths[0], Value: data}})
if err != nil {
return err
}
if len(insertData.Data) != 1 {
return merr.WrapErrParameterInvalidMsg("we expect only one field in deserialized insert data")
}
it.req.CollectionID = collectionID
it.req.PartitionID = partitionID
it.req.SegmentID = segmentID
if it.req.GetField().GetFieldID() == 0 {
for fID, value := range insertData.Data {
it.req.Field.DataType = value.GetDataType()
it.req.Field.FieldID = fID
break
}
}
it.req.CurrentIndexVersion = getCurrentIndexVersion(it.req.GetCurrentIndexVersion())
return nil
}

View File

@ -26,7 +26,7 @@ import (
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/merr"
@ -147,7 +147,7 @@ func (queue *IndexTaskQueue) GetTaskNum() (int, int) {
atNum := 0
// remove the finished task
for _, task := range queue.activeTasks {
if task.GetState() != commonpb.IndexState_Finished && task.GetState() != commonpb.IndexState_Failed {
if task.GetState() != indexpb.JobState_JobStateFinished && task.GetState() != indexpb.JobState_JobStateFailed {
atNum++
}
}
@ -168,7 +168,7 @@ func NewIndexBuildTaskQueue(sched *TaskScheduler) *IndexTaskQueue {
// TaskScheduler is a scheduler of indexing tasks.
type TaskScheduler struct {
IndexBuildQueue TaskQueue
TaskQueue TaskQueue
buildParallel int
wg sync.WaitGroup
@ -184,7 +184,7 @@ func NewTaskScheduler(ctx context.Context) *TaskScheduler {
cancel: cancel,
buildParallel: Params.IndexNodeCfg.BuildParallel.GetAsInt(),
}
s.IndexBuildQueue = NewIndexBuildTaskQueue(s)
s.TaskQueue = NewIndexBuildTaskQueue(s)
return s
}
@ -192,7 +192,7 @@ func NewTaskScheduler(ctx context.Context) *TaskScheduler {
func (sched *TaskScheduler) scheduleIndexBuildTask() []task {
ret := make([]task, 0)
for i := 0; i < sched.buildParallel; i++ {
t := sched.IndexBuildQueue.PopUnissuedTask()
t := sched.TaskQueue.PopUnissuedTask()
if t == nil {
return ret
}
@ -201,14 +201,16 @@ func (sched *TaskScheduler) scheduleIndexBuildTask() []task {
return ret
}
func getStateFromError(err error) commonpb.IndexState {
func getStateFromError(err error) indexpb.JobState {
if errors.Is(err, errCancel) {
return commonpb.IndexState_Retry
return indexpb.JobState_JobStateRetry
} else if errors.Is(err, merr.ErrIoKeyNotFound) || errors.Is(err, merr.ErrSegcoreUnsupported) {
// NoSuchKey or unsupported error
return commonpb.IndexState_Failed
return indexpb.JobState_JobStateFailed
} else if errors.Is(err, merr.ErrSegcorePretendFinished) {
return indexpb.JobState_JobStateFinished
}
return commonpb.IndexState_Retry
return indexpb.JobState_JobStateRetry
}
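Because the before/after lines are interleaved in the hunk above, here is the resulting mapping reconstructed from the added lines only:

// getStateFromError after this change, as implied by the added lines above.
func getStateFromError(err error) indexpb.JobState {
	if errors.Is(err, errCancel) {
		return indexpb.JobState_JobStateRetry
	} else if errors.Is(err, merr.ErrIoKeyNotFound) || errors.Is(err, merr.ErrSegcoreUnsupported) {
		// NoSuchKey or unsupported error
		return indexpb.JobState_JobStateFailed
	} else if errors.Is(err, merr.ErrSegcorePretendFinished) {
		return indexpb.JobState_JobStateFinished
	}
	return indexpb.JobState_JobStateRetry
}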
func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
@ -225,10 +227,10 @@ func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
t.Reset()
debug.FreeOSMemory()
}()
sched.IndexBuildQueue.AddActiveTask(t)
defer sched.IndexBuildQueue.PopActiveTask(t.Name())
sched.TaskQueue.AddActiveTask(t)
defer sched.TaskQueue.PopActiveTask(t.Name())
log.Ctx(t.Ctx()).Debug("process task", zap.String("task", t.Name()))
pipelines := []func(context.Context) error{t.Prepare, t.BuildIndex, t.SaveIndexFiles}
pipelines := []func(context.Context) error{t.PreExecute, t.Execute, t.PostExecute}
for _, fn := range pipelines {
if err := wrap(fn); err != nil {
log.Ctx(t.Ctx()).Warn("process task failed", zap.Error(err))
@ -236,7 +238,7 @@ func (sched *TaskScheduler) processTask(t task, q TaskQueue) {
return
}
}
t.SetState(commonpb.IndexState_Finished, "")
t.SetState(indexpb.JobState_JobStateFinished, "")
if indexBuildTask, ok := t.(*indexBuildTask); ok {
metrics.IndexNodeBuildIndexLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(indexBuildTask.tr.ElapseSpan().Seconds())
metrics.IndexNodeIndexTaskLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(indexBuildTask.queueDur.Milliseconds()))
@ -250,14 +252,14 @@ func (sched *TaskScheduler) indexBuildLoop() {
select {
case <-sched.ctx.Done():
return
case <-sched.IndexBuildQueue.utChan():
case <-sched.TaskQueue.utChan():
tasks := sched.scheduleIndexBuildTask()
var wg sync.WaitGroup
for _, t := range tasks {
wg.Add(1)
go func(group *sync.WaitGroup, t task) {
defer group.Done()
sched.processTask(t, sched.IndexBuildQueue)
sched.processTask(t, sched.TaskQueue)
}(&wg, t)
}
wg.Wait()

View File

@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@ -72,8 +72,8 @@ type fakeTask struct {
ctx context.Context
state fakeTaskState
reterr map[fakeTaskState]error
retstate commonpb.IndexState
expectedState commonpb.IndexState
retstate indexpb.JobState
expectedState indexpb.JobState
failReason string
}
@ -94,7 +94,7 @@ func (t *fakeTask) OnEnqueue(ctx context.Context) error {
return t.reterr[t.state]
}
func (t *fakeTask) Prepare(ctx context.Context) error {
func (t *fakeTask) PreExecute(ctx context.Context) error {
t.state = fakeTaskPrepared
t.ctx.(*stagectx).setState(t.state)
return t.reterr[t.state]
@ -106,13 +106,13 @@ func (t *fakeTask) LoadData(ctx context.Context) error {
return t.reterr[t.state]
}
func (t *fakeTask) BuildIndex(ctx context.Context) error {
func (t *fakeTask) Execute(ctx context.Context) error {
t.state = fakeTaskBuiltIndex
t.ctx.(*stagectx).setState(t.state)
return t.reterr[t.state]
}
func (t *fakeTask) SaveIndexFiles(ctx context.Context) error {
func (t *fakeTask) PostExecute(ctx context.Context) error {
t.state = fakeTaskSavedIndexes
t.ctx.(*stagectx).setState(t.state)
return t.reterr[t.state]
@ -122,12 +122,12 @@ func (t *fakeTask) Reset() {
_taskwg.Done()
}
func (t *fakeTask) SetState(state commonpb.IndexState, failReason string) {
func (t *fakeTask) SetState(state indexpb.JobState, failReason string) {
t.retstate = state
t.failReason = failReason
}
func (t *fakeTask) GetState() commonpb.IndexState {
func (t *fakeTask) GetState() indexpb.JobState {
return t.retstate
}
@ -136,7 +136,7 @@ var (
id = 0
)
func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expectedState commonpb.IndexState) task {
func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expectedState indexpb.JobState) task {
idLock.Lock()
newID := id
id++
@ -151,7 +151,7 @@ func newTask(cancelStage fakeTaskState, reterror map[fakeTaskState]error, expect
ch: make(chan struct{}),
},
state: fakeTaskInited,
retstate: commonpb.IndexState_IndexStateNone,
retstate: indexpb.JobState_JobStateNone,
expectedState: expectedState,
}
}
@ -165,14 +165,14 @@ func TestIndexTaskScheduler(t *testing.T) {
tasks := make([]task, 0)
tasks = append(tasks,
newTask(fakeTaskEnqueued, nil, commonpb.IndexState_Retry),
newTask(fakeTaskPrepared, nil, commonpb.IndexState_Retry),
newTask(fakeTaskBuiltIndex, nil, commonpb.IndexState_Retry),
newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished),
newTask(fakeTaskSavedIndexes, map[fakeTaskState]error{fakeTaskSavedIndexes: fmt.Errorf("auth failed")}, commonpb.IndexState_Retry))
newTask(fakeTaskEnqueued, nil, indexpb.JobState_JobStateRetry),
newTask(fakeTaskPrepared, nil, indexpb.JobState_JobStateRetry),
newTask(fakeTaskBuiltIndex, nil, indexpb.JobState_JobStateRetry),
newTask(fakeTaskSavedIndexes, nil, indexpb.JobState_JobStateFinished),
newTask(fakeTaskSavedIndexes, map[fakeTaskState]error{fakeTaskSavedIndexes: fmt.Errorf("auth failed")}, indexpb.JobState_JobStateRetry))
for _, task := range tasks {
assert.Nil(t, scheduler.IndexBuildQueue.Enqueue(task))
assert.Nil(t, scheduler.TaskQueue.Enqueue(task))
}
_taskwg.Wait()
scheduler.Close()
@ -189,11 +189,11 @@ func TestIndexTaskScheduler(t *testing.T) {
scheduler = NewTaskScheduler(context.TODO())
tasks = make([]task, 0, 1024)
for i := 0; i < 1024; i++ {
tasks = append(tasks, newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished))
assert.Nil(t, scheduler.IndexBuildQueue.Enqueue(tasks[len(tasks)-1]))
tasks = append(tasks, newTask(fakeTaskSavedIndexes, nil, indexpb.JobState_JobStateFinished))
assert.Nil(t, scheduler.TaskQueue.Enqueue(tasks[len(tasks)-1]))
}
failTask := newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished)
err := scheduler.IndexBuildQueue.Enqueue(failTask)
failTask := newTask(fakeTaskSavedIndexes, nil, indexpb.JobState_JobStateFinished)
err := scheduler.TaskQueue.Enqueue(failTask)
assert.Error(t, err)
failTask.Reset()
@ -202,6 +202,6 @@ func TestIndexTaskScheduler(t *testing.T) {
scheduler.Close()
scheduler.wg.Wait()
for _, task := range tasks {
assert.Equal(t, task.GetState(), commonpb.IndexState_Finished)
assert.Equal(t, task.GetState(), indexpb.JobState_JobStateFinished)
}
}

View File

@ -30,14 +30,115 @@ import (
milvus_storage "github.com/milvus-io/milvus-storage/go/storage"
"github.com/milvus-io/milvus-storage/go/storage/options"
"github.com/milvus-io/milvus-storage/go/storage/schema"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)
type IndexBuildTaskSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
collectionID int64
partitionID int64
segmentID int64
dataPath string
numRows int
dim int
}
func (suite *IndexBuildTaskSuite) SetupSuite() {
paramtable.Init()
suite.collectionID = 1000
suite.partitionID = 1001
suite.segmentID = 1002
suite.dataPath = "/tmp/milvus/data/1000/1001/1002/3/1"
suite.numRows = 100
suite.dim = 128
}
func (suite *IndexBuildTaskSuite) SetupTest() {
suite.schema = &schemapb.CollectionSchema{
Name: "test",
Description: "test",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{FieldID: common.RowIDField, Name: common.RowIDFieldName, DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: common.TimeStampField, Name: common.TimeStampFieldName, DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "ts", DataType: schemapb.DataType_Int64},
{FieldID: 102, Name: "vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
},
}
}
func (suite *IndexBuildTaskSuite) serializeData() ([]*storage.Blob, error) {
insertCodec := storage.NewInsertCodecWithSchema(&etcdpb.CollectionMeta{
Schema: suite.schema,
})
return insertCodec.Serialize(suite.partitionID, suite.segmentID, &storage.InsertData{
Data: map[storage.FieldID]storage.FieldData{
0: &storage.Int64FieldData{Data: generateLongs(suite.numRows)},
1: &storage.Int64FieldData{Data: generateLongs(suite.numRows)},
100: &storage.Int64FieldData{Data: generateLongs(suite.numRows)},
101: &storage.Int64FieldData{Data: generateLongs(suite.numRows)},
102: &storage.FloatVectorFieldData{Data: generateFloats(suite.numRows * suite.dim), Dim: suite.dim},
},
Infos: []storage.BlobInfo{{Length: suite.numRows}},
})
}
func (suite *IndexBuildTaskSuite) TestBuildMemoryIndex() {
ctx, cancel := context.WithCancel(context.Background())
req := &indexpb.CreateJobRequest{
BuildID: 1,
IndexVersion: 1,
DataPaths: []string{suite.dataPath},
IndexID: 0,
IndexName: "",
IndexParams: []*commonpb.KeyValuePair{{Key: common.IndexTypeKey, Value: "FLAT"}, {Key: common.MetricTypeKey, Value: metric.L2}},
TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}},
NumRows: int64(suite.numRows),
StorageConfig: &indexpb.StorageConfig{
RootPath: "/tmp/milvus/data",
StorageType: "local",
},
CollectionID: 1,
PartitionID: 2,
SegmentID: 3,
FieldID: 102,
FieldName: "vec",
FieldType: schemapb.DataType_FloatVector,
}
cm, err := NewChunkMgrFactory().NewChunkManager(ctx, req.GetStorageConfig())
suite.NoError(err)
blobs, err := suite.serializeData()
suite.NoError(err)
err = cm.Write(ctx, suite.dataPath, blobs[0].Value)
suite.NoError(err)
t := newIndexBuildTask(ctx, cancel, req, cm, NewIndexNode(context.Background(), dependency.NewDefaultFactory(true)))
err = t.PreExecute(context.Background())
suite.NoError(err)
err = t.Execute(context.Background())
suite.NoError(err)
err = t.PostExecute(context.Background())
suite.NoError(err)
}
func TestIndexBuildTask(t *testing.T) {
suite.Run(t, new(IndexBuildTaskSuite))
}
type IndexBuildTaskV2Suite struct {
suite.Suite
schema *schemapb.CollectionSchema
@ -125,14 +226,117 @@ func (suite *IndexBuildTaskV2Suite) TestBuildIndex() {
task := newIndexBuildTaskV2(context.Background(), nil, req, NewIndexNode(context.Background(), dependency.NewDefaultFactory(true)))
var err error
err = task.Prepare(context.Background())
err = task.PreExecute(context.Background())
suite.NoError(err)
err = task.BuildIndex(context.Background())
err = task.Execute(context.Background())
suite.NoError(err)
err = task.SaveIndexFiles(context.Background())
err = task.PostExecute(context.Background())
suite.NoError(err)
}
func TestIndexBuildTaskV2Suite(t *testing.T) {
suite.Run(t, new(IndexBuildTaskV2Suite))
}
type AnalyzeTaskSuite struct {
suite.Suite
schema *schemapb.CollectionSchema
collectionID int64
partitionID int64
segmentID int64
fieldID int64
taskID int64
}
func (suite *AnalyzeTaskSuite) SetupSuite() {
paramtable.Init()
suite.collectionID = 1000
suite.partitionID = 1001
suite.segmentID = 1002
suite.fieldID = 102
suite.taskID = 1004
}
func (suite *AnalyzeTaskSuite) SetupTest() {
suite.schema = &schemapb.CollectionSchema{
Name: "test",
Description: "test",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{FieldID: common.RowIDField, Name: common.RowIDFieldName, DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: common.TimeStampField, Name: common.TimeStampFieldName, DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "ts", DataType: schemapb.DataType_Int64},
{FieldID: 102, Name: "vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "1"}}},
},
}
}
func (suite *AnalyzeTaskSuite) serializeData() ([]*storage.Blob, error) {
insertCodec := storage.NewInsertCodecWithSchema(&etcdpb.CollectionMeta{
Schema: suite.schema,
})
return insertCodec.Serialize(suite.partitionID, suite.segmentID, &storage.InsertData{
Data: map[storage.FieldID]storage.FieldData{
0: &storage.Int64FieldData{Data: []int64{0, 1, 2}},
1: &storage.Int64FieldData{Data: []int64{1, 2, 3}},
100: &storage.Int64FieldData{Data: []int64{0, 1, 2}},
101: &storage.Int64FieldData{Data: []int64{0, 1, 2}},
102: &storage.FloatVectorFieldData{Data: []float32{1, 2, 3}, Dim: 1},
},
Infos: []storage.BlobInfo{{Length: 3}},
})
}
func (suite *AnalyzeTaskSuite) TestAnalyze() {
ctx, cancel := context.WithCancel(context.Background())
req := &indexpb.AnalyzeRequest{
ClusterID: "test",
TaskID: 1,
CollectionID: suite.collectionID,
PartitionID: suite.partitionID,
FieldID: suite.fieldID,
FieldName: "vec",
FieldType: schemapb.DataType_FloatVector,
SegmentStats: map[int64]*indexpb.SegmentStats{
suite.segmentID: {
ID: suite.segmentID,
NumRows: 1024,
LogIDs: []int64{1},
},
},
Version: 1,
StorageConfig: &indexpb.StorageConfig{
RootPath: "/tmp/milvus/data",
StorageType: "local",
},
Dim: 1,
}
cm, err := NewChunkMgrFactory().NewChunkManager(ctx, req.GetStorageConfig())
suite.NoError(err)
blobs, err := suite.serializeData()
suite.NoError(err)
dataPath := metautil.BuildInsertLogPath(cm.RootPath(), suite.collectionID, suite.partitionID, suite.segmentID,
suite.fieldID, 1)
err = cm.Write(ctx, dataPath, blobs[0].Value)
suite.NoError(err)
t := &analyzeTask{
ident: "",
cancel: cancel,
ctx: ctx,
req: req,
tr: timerecord.NewTimeRecorder("test-indexBuildTask"),
queueDur: 0,
node: NewIndexNode(context.Background(), dependency.NewDefaultFactory(true)),
}
err = t.PreExecute(context.Background())
suite.NoError(err)
}
func TestAnalyzeTaskSuite(t *testing.T) {
suite.Run(t, new(AnalyzeTaskSuite))
}

View File

@ -7,38 +7,52 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
)
func (i *IndexNode) loadOrStoreTask(ClusterID string, buildID UniqueID, info *taskInfo) *taskInfo {
type indexTaskInfo struct {
cancel context.CancelFunc
state commonpb.IndexState
fileKeys []string
serializedSize uint64
failReason string
currentIndexVersion int32
indexStoreVersion int64
// task statistics
statistic *indexpb.JobInfo
}
func (i *IndexNode) loadOrStoreIndexTask(ClusterID string, buildID UniqueID, info *indexTaskInfo) *indexTaskInfo {
i.stateLock.Lock()
defer i.stateLock.Unlock()
key := taskKey{ClusterID: ClusterID, BuildID: buildID}
oldInfo, ok := i.tasks[key]
oldInfo, ok := i.indexTasks[key]
if ok {
return oldInfo
}
i.tasks[key] = info
i.indexTasks[key] = info
return nil
}
func (i *IndexNode) loadTaskState(ClusterID string, buildID UniqueID) commonpb.IndexState {
func (i *IndexNode) loadIndexTaskState(ClusterID string, buildID UniqueID) commonpb.IndexState {
key := taskKey{ClusterID: ClusterID, BuildID: buildID}
i.stateLock.Lock()
defer i.stateLock.Unlock()
task, ok := i.tasks[key]
task, ok := i.indexTasks[key]
if !ok {
return commonpb.IndexState_IndexStateNone
}
return task.state
}
func (i *IndexNode) storeTaskState(ClusterID string, buildID UniqueID, state commonpb.IndexState, failReason string) {
func (i *IndexNode) storeIndexTaskState(ClusterID string, buildID UniqueID, state commonpb.IndexState, failReason string) {
key := taskKey{ClusterID: ClusterID, BuildID: buildID}
i.stateLock.Lock()
defer i.stateLock.Unlock()
if task, ok := i.tasks[key]; ok {
if task, ok := i.indexTasks[key]; ok {
log.Debug("IndexNode store task state", zap.String("clusterID", ClusterID), zap.Int64("buildID", buildID),
zap.String("state", state.String()), zap.String("fail reason", failReason))
task.state = state
@ -46,10 +60,10 @@ func (i *IndexNode) storeTaskState(ClusterID string, buildID UniqueID, state com
}
}
func (i *IndexNode) foreachTaskInfo(fn func(ClusterID string, buildID UniqueID, info *taskInfo)) {
func (i *IndexNode) foreachIndexTaskInfo(fn func(ClusterID string, buildID UniqueID, info *indexTaskInfo)) {
i.stateLock.Lock()
defer i.stateLock.Unlock()
for key, info := range i.tasks {
for key, info := range i.indexTasks {
fn(key.ClusterID, key.BuildID, info)
}
}
@ -64,7 +78,7 @@ func (i *IndexNode) storeIndexFilesAndStatistic(
key := taskKey{ClusterID: ClusterID, BuildID: buildID}
i.stateLock.Lock()
defer i.stateLock.Unlock()
if info, ok := i.tasks[key]; ok {
if info, ok := i.indexTasks[key]; ok {
info.fileKeys = common.CloneStringList(fileKeys)
info.serializedSize = serializedSize
info.currentIndexVersion = currentIndexVersion
@ -83,7 +97,7 @@ func (i *IndexNode) storeIndexFilesAndStatisticV2(
key := taskKey{ClusterID: ClusterID, BuildID: buildID}
i.stateLock.Lock()
defer i.stateLock.Unlock()
if info, ok := i.tasks[key]; ok {
if info, ok := i.indexTasks[key]; ok {
info.fileKeys = common.CloneStringList(fileKeys)
info.serializedSize = serializedSize
info.currentIndexVersion = currentIndexVersion
@ -92,15 +106,15 @@ func (i *IndexNode) storeIndexFilesAndStatisticV2(
}
}
func (i *IndexNode) deleteTaskInfos(ctx context.Context, keys []taskKey) []*taskInfo {
func (i *IndexNode) deleteIndexTaskInfos(ctx context.Context, keys []taskKey) []*indexTaskInfo {
i.stateLock.Lock()
defer i.stateLock.Unlock()
deleted := make([]*taskInfo, 0, len(keys))
deleted := make([]*indexTaskInfo, 0, len(keys))
for _, key := range keys {
info, ok := i.tasks[key]
info, ok := i.indexTasks[key]
if ok {
deleted = append(deleted, info)
delete(i.tasks, key)
delete(i.indexTasks, key)
log.Ctx(ctx).Info("delete task infos",
zap.String("cluster_id", key.ClusterID), zap.Int64("build_id", key.BuildID))
}
@ -108,13 +122,113 @@ func (i *IndexNode) deleteTaskInfos(ctx context.Context, keys []taskKey) []*task
return deleted
}
func (i *IndexNode) deleteAllTasks() []*taskInfo {
func (i *IndexNode) deleteAllIndexTasks() []*indexTaskInfo {
i.stateLock.Lock()
deletedTasks := i.tasks
i.tasks = make(map[taskKey]*taskInfo)
deletedTasks := i.indexTasks
i.indexTasks = make(map[taskKey]*indexTaskInfo)
i.stateLock.Unlock()
deleted := make([]*taskInfo, 0, len(deletedTasks))
deleted := make([]*indexTaskInfo, 0, len(deletedTasks))
for _, info := range deletedTasks {
deleted = append(deleted, info)
}
return deleted
}
type analyzeTaskInfo struct {
cancel context.CancelFunc
state indexpb.JobState
failReason string
centroidsFile string
}
func (i *IndexNode) loadOrStoreAnalyzeTask(clusterID string, taskID UniqueID, info *analyzeTaskInfo) *analyzeTaskInfo {
i.stateLock.Lock()
defer i.stateLock.Unlock()
key := taskKey{ClusterID: clusterID, BuildID: taskID}
oldInfo, ok := i.analyzeTasks[key]
if ok {
return oldInfo
}
i.analyzeTasks[key] = info
return nil
}
func (i *IndexNode) loadAnalyzeTaskState(clusterID string, taskID UniqueID) indexpb.JobState {
key := taskKey{ClusterID: clusterID, BuildID: taskID}
i.stateLock.Lock()
defer i.stateLock.Unlock()
task, ok := i.analyzeTasks[key]
if !ok {
return indexpb.JobState_JobStateNone
}
return task.state
}
func (i *IndexNode) storeAnalyzeTaskState(clusterID string, taskID UniqueID, state indexpb.JobState, failReason string) {
key := taskKey{ClusterID: clusterID, BuildID: taskID}
i.stateLock.Lock()
defer i.stateLock.Unlock()
if task, ok := i.analyzeTasks[key]; ok {
log.Info("IndexNode store analyze task state", zap.String("clusterID", clusterID), zap.Int64("taskID", taskID),
zap.String("state", state.String()), zap.String("fail reason", failReason))
task.state = state
task.failReason = failReason
}
}
func (i *IndexNode) foreachAnalyzeTaskInfo(fn func(clusterID string, taskID UniqueID, info *analyzeTaskInfo)) {
i.stateLock.Lock()
defer i.stateLock.Unlock()
for key, info := range i.analyzeTasks {
fn(key.ClusterID, key.BuildID, info)
}
}
func (i *IndexNode) storeAnalyzeFilesAndStatistic(
ClusterID string,
taskID UniqueID,
centroidsFile string,
) {
key := taskKey{ClusterID: ClusterID, BuildID: taskID}
i.stateLock.Lock()
defer i.stateLock.Unlock()
if info, ok := i.analyzeTasks[key]; ok {
info.centroidsFile = centroidsFile
return
}
}
func (i *IndexNode) getAnalyzeTaskInfo(clusterID string, taskID UniqueID) *analyzeTaskInfo {
i.stateLock.Lock()
defer i.stateLock.Unlock()
return i.analyzeTasks[taskKey{ClusterID: clusterID, BuildID: taskID}]
}
func (i *IndexNode) deleteAnalyzeTaskInfos(ctx context.Context, keys []taskKey) []*analyzeTaskInfo {
i.stateLock.Lock()
defer i.stateLock.Unlock()
deleted := make([]*analyzeTaskInfo, 0, len(keys))
for _, key := range keys {
info, ok := i.analyzeTasks[key]
if ok {
deleted = append(deleted, info)
delete(i.analyzeTasks, key)
log.Ctx(ctx).Info("delete analyze task infos",
zap.String("clusterID", key.ClusterID), zap.Int64("taskID", key.BuildID))
}
}
return deleted
}
func (i *IndexNode) deleteAllAnalyzeTasks() []*analyzeTaskInfo {
i.stateLock.Lock()
deletedTasks := i.analyzeTasks
i.analyzeTasks = make(map[taskKey]*analyzeTaskInfo)
i.stateLock.Unlock()
deleted := make([]*analyzeTaskInfo, 0, len(deletedTasks))
for _, info := range deletedTasks {
deleted = append(deleted, info)
}
@ -124,11 +238,17 @@ func (i *IndexNode) deleteAllTasks() []*taskInfo {
func (i *IndexNode) hasInProgressTask() bool {
i.stateLock.Lock()
defer i.stateLock.Unlock()
for _, info := range i.tasks {
for _, info := range i.indexTasks {
if info.state == commonpb.IndexState_InProgress {
return true
}
}
for _, info := range i.analyzeTasks {
if info.state == indexpb.JobState_JobStateInProgress {
return true
}
}
return false
}
@ -151,11 +271,16 @@ func (i *IndexNode) waitTaskFinish() {
}
case <-timeoutCtx.Done():
log.Warn("timeout, the index node has some progress task")
for _, info := range i.tasks {
for _, info := range i.indexTasks {
if info.state == commonpb.IndexState_InProgress {
log.Warn("progress task", zap.Any("info", info))
}
}
for _, info := range i.analyzeTasks {
if info.state == indexpb.JobState_JobStateInProgress {
log.Warn("progress task", zap.Any("info", info))
}
}
return
}
}
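Taken together, the helpers above define the expected lifecycle of one analyze task on the IndexNode. The following is a minimal sketch of that flow; the IDs, the centroids file key, and the call sites are illustrative assumptions only.

func runAnalyzeTaskLifecycle(i *IndexNode, clusterID string, taskID UniqueID, cancel context.CancelFunc) {
    // Register the task; a non-nil return means the same key is already registered.
    if prev := i.loadOrStoreAnalyzeTask(clusterID, taskID, &analyzeTaskInfo{
        cancel: cancel,
        state:  indexpb.JobState_JobStateInit,
    }); prev != nil {
        return
    }
    // Mark the task as running while the scheduler executes it.
    i.storeAnalyzeTaskState(clusterID, taskID, indexpb.JobState_JobStateInProgress, "")

    // ... the analysis itself runs elsewhere; on success, record the output and the final state.
    i.storeAnalyzeFilesAndStatistic(clusterID, taskID, "analyze_stats/1004/1/centroids") // illustrative key
    i.storeAnalyzeTaskState(clusterID, taskID, indexpb.JobState_JobStateFinished, "")

    // QueryJobsV2 handlers can later read the result back through getAnalyzeTaskInfo.
    _ = i.getAnalyzeTaskInfo(clusterID, taskID)
}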

View File

@ -17,6 +17,7 @@
package indexnode
import (
"math/rand"
"testing"
"github.com/stretchr/testify/suite"
@ -39,3 +40,19 @@ func (s *utilSuite) Test_mapToKVPairs() {
func Test_utilSuite(t *testing.T) {
suite.Run(t, new(utilSuite))
}
func generateFloats(num int) []float32 {
data := make([]float32, num)
for i := 0; i < num; i++ {
data[i] = rand.Float32()
}
return data
}
func generateLongs(num int) []int64 {
data := make([]int64, num)
for i := 0; i < num; i++ {
data[i] = rand.Int63()
}
return data
}

View File

@ -7,6 +7,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -152,6 +153,10 @@ type DataCoordCatalog interface {
ListCompactionTask(ctx context.Context) ([]*datapb.CompactionTask, error)
SaveCompactionTask(ctx context.Context, task *datapb.CompactionTask) error
DropCompactionTask(ctx context.Context, task *datapb.CompactionTask) error
ListAnalyzeTasks(ctx context.Context) ([]*indexpb.AnalyzeTask, error)
SaveAnalyzeTask(ctx context.Context, task *indexpb.AnalyzeTask) error
DropAnalyzeTask(ctx context.Context, taskID typeutil.UniqueID) error
}
type QueryCoordCatalog interface {

View File

@ -28,6 +28,7 @@ const (
ImportTaskPrefix = MetaPrefix + "/import-task"
PreImportTaskPrefix = MetaPrefix + "/preimport-task"
CompactionTaskPrefix = MetaPrefix + "/compaction-task"
AnalyzeTaskPrefix = MetaPrefix + "/analyze-task"
NonRemoveFlagTomestone = "non-removed"
RemoveFlagTomestone = "removed"

View File

@ -834,3 +834,41 @@ func (kc *Catalog) DropCompactionTask(ctx context.Context, task *datapb.Compacti
key := buildCompactionTaskPath(task)
return kc.MetaKv.Remove(key)
}
func (kc *Catalog) ListAnalyzeTasks(ctx context.Context) ([]*indexpb.AnalyzeTask, error) {
tasks := make([]*indexpb.AnalyzeTask, 0)
_, values, err := kc.MetaKv.LoadWithPrefix(AnalyzeTaskPrefix)
if err != nil {
return nil, err
}
for _, value := range values {
task := &indexpb.AnalyzeTask{}
err = proto.Unmarshal([]byte(value), task)
if err != nil {
return nil, err
}
tasks = append(tasks, task)
}
return tasks, nil
}
func (kc *Catalog) SaveAnalyzeTask(ctx context.Context, task *indexpb.AnalyzeTask) error {
key := buildAnalyzeTaskKey(task.TaskID)
value, err := proto.Marshal(task)
if err != nil {
return err
}
err = kc.MetaKv.Save(key, string(value))
if err != nil {
return err
}
return nil
}
func (kc *Catalog) DropAnalyzeTask(ctx context.Context, taskID typeutil.UniqueID) error {
key := buildAnalyzeTaskKey(taskID)
return kc.MetaKv.Remove(key)
}
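A hedged sketch of the round trip through these catalog methods; the field values are placeholders and the restart scenario is assumed.

func persistAnalyzeTask(ctx context.Context, kc *Catalog) error {
    task := &indexpb.AnalyzeTask{
        TaskID:       1004,
        CollectionID: 1000,
        PartitionID:  1001,
        FieldID:      102,
        State:        indexpb.JobState_JobStateInit,
    }
    // Persist the task so it survives a DataCoord restart.
    if err := kc.SaveAnalyzeTask(ctx, task); err != nil {
        return err
    }
    // On restart, every persisted analyze task can be reloaded in one call.
    if _, err := kc.ListAnalyzeTasks(ctx); err != nil {
        return err
    }
    // Drop the record once its result is no longer needed.
    return kc.DropAnalyzeTask(ctx, task.TaskID)
}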

View File

@ -325,3 +325,7 @@ func buildImportTaskKey(taskID int64) string {
func buildPreImportTaskKey(taskID int64) string {
return fmt.Sprintf("%s/%d", PreImportTaskPrefix, taskID)
}
func buildAnalyzeTaskKey(taskID int64) string {
return fmt.Sprintf("%s/%d", AnalyzeTaskPrefix, taskID)
}

View File

@ -6,6 +6,7 @@ import (
context "context"
datapb "github.com/milvus-io/milvus/internal/proto/datapb"
indexpb "github.com/milvus-io/milvus/internal/proto/indexpb"
metastore "github.com/milvus-io/milvus/internal/metastore"
@ -345,6 +346,49 @@ func (_c *DataCoordCatalog_CreateSegmentIndex_Call) RunAndReturn(run func(contex
return _c
}
// DropAnalyzeTask provides a mock function with given fields: ctx, taskID
func (_m *DataCoordCatalog) DropAnalyzeTask(ctx context.Context, taskID int64) error {
ret := _m.Called(ctx, taskID)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok {
r0 = rf(ctx, taskID)
} else {
r0 = ret.Error(0)
}
return r0
}
// DataCoordCatalog_DropAnalyzeTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropAnalyzeTask'
type DataCoordCatalog_DropAnalyzeTask_Call struct {
*mock.Call
}
// DropAnalyzeTask is a helper method to define mock.On call
// - ctx context.Context
// - taskID int64
func (_e *DataCoordCatalog_Expecter) DropAnalyzeTask(ctx interface{}, taskID interface{}) *DataCoordCatalog_DropAnalyzeTask_Call {
return &DataCoordCatalog_DropAnalyzeTask_Call{Call: _e.mock.On("DropAnalyzeTask", ctx, taskID)}
}
func (_c *DataCoordCatalog_DropAnalyzeTask_Call) Run(run func(ctx context.Context, taskID int64)) *DataCoordCatalog_DropAnalyzeTask_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *DataCoordCatalog_DropAnalyzeTask_Call) Return(_a0 error) *DataCoordCatalog_DropAnalyzeTask_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *DataCoordCatalog_DropAnalyzeTask_Call) RunAndReturn(run func(context.Context, int64) error) *DataCoordCatalog_DropAnalyzeTask_Call {
_c.Call.Return(run)
return _c
}
// DropChannel provides a mock function with given fields: ctx, channel
func (_m *DataCoordCatalog) DropChannel(ctx context.Context, channel string) error {
ret := _m.Called(ctx, channel)
@ -777,6 +821,60 @@ func (_c *DataCoordCatalog_GcConfirm_Call) RunAndReturn(run func(context.Context
return _c
}
// ListAnalyzeTasks provides a mock function with given fields: ctx
func (_m *DataCoordCatalog) ListAnalyzeTasks(ctx context.Context) ([]*indexpb.AnalyzeTask, error) {
ret := _m.Called(ctx)
var r0 []*indexpb.AnalyzeTask
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) ([]*indexpb.AnalyzeTask, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func(context.Context) []*indexpb.AnalyzeTask); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*indexpb.AnalyzeTask)
}
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// DataCoordCatalog_ListAnalyzeTasks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAnalyzeTasks'
type DataCoordCatalog_ListAnalyzeTasks_Call struct {
*mock.Call
}
// ListAnalyzeTasks is a helper method to define mock.On call
// - ctx context.Context
func (_e *DataCoordCatalog_Expecter) ListAnalyzeTasks(ctx interface{}) *DataCoordCatalog_ListAnalyzeTasks_Call {
return &DataCoordCatalog_ListAnalyzeTasks_Call{Call: _e.mock.On("ListAnalyzeTasks", ctx)}
}
func (_c *DataCoordCatalog_ListAnalyzeTasks_Call) Run(run func(ctx context.Context)) *DataCoordCatalog_ListAnalyzeTasks_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *DataCoordCatalog_ListAnalyzeTasks_Call) Return(_a0 []*indexpb.AnalyzeTask, _a1 error) *DataCoordCatalog_ListAnalyzeTasks_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *DataCoordCatalog_ListAnalyzeTasks_Call) RunAndReturn(run func(context.Context) ([]*indexpb.AnalyzeTask, error)) *DataCoordCatalog_ListAnalyzeTasks_Call {
_c.Call.Return(run)
return _c
}
// ListChannelCheckpoint provides a mock function with given fields: ctx
func (_m *DataCoordCatalog) ListChannelCheckpoint(ctx context.Context) (map[string]*msgpb.MsgPosition, error) {
ret := _m.Called(ctx)
@ -1292,6 +1390,49 @@ func (_c *DataCoordCatalog_MarkChannelDeleted_Call) RunAndReturn(run func(contex
return _c
}
// SaveAnalyzeTask provides a mock function with given fields: ctx, task
func (_m *DataCoordCatalog) SaveAnalyzeTask(ctx context.Context, task *indexpb.AnalyzeTask) error {
ret := _m.Called(ctx, task)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.AnalyzeTask) error); ok {
r0 = rf(ctx, task)
} else {
r0 = ret.Error(0)
}
return r0
}
// DataCoordCatalog_SaveAnalyzeTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveAnalyzeTask'
type DataCoordCatalog_SaveAnalyzeTask_Call struct {
*mock.Call
}
// SaveAnalyzeTask is a helper method to define mock.On call
// - ctx context.Context
// - task *indexpb.AnalyzeTask
func (_e *DataCoordCatalog_Expecter) SaveAnalyzeTask(ctx interface{}, task interface{}) *DataCoordCatalog_SaveAnalyzeTask_Call {
return &DataCoordCatalog_SaveAnalyzeTask_Call{Call: _e.mock.On("SaveAnalyzeTask", ctx, task)}
}
func (_c *DataCoordCatalog_SaveAnalyzeTask_Call) Run(run func(ctx context.Context, task *indexpb.AnalyzeTask)) *DataCoordCatalog_SaveAnalyzeTask_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*indexpb.AnalyzeTask))
})
return _c
}
func (_c *DataCoordCatalog_SaveAnalyzeTask_Call) Return(_a0 error) *DataCoordCatalog_SaveAnalyzeTask_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *DataCoordCatalog_SaveAnalyzeTask_Call) RunAndReturn(run func(context.Context, *indexpb.AnalyzeTask) error) *DataCoordCatalog_SaveAnalyzeTask_Call {
_c.Call.Return(run)
return _c
}
// SaveChannelCheckpoint provides a mock function with given fields: ctx, vChannel, pos
func (_m *DataCoordCatalog) SaveChannelCheckpoint(ctx context.Context, vChannel string, pos *msgpb.MsgPosition) error {
ret := _m.Called(ctx, vChannel, pos)

View File

@ -85,6 +85,61 @@ func (_c *MockIndexNode_CreateJob_Call) RunAndReturn(run func(context.Context, *
return _c
}
// CreateJobV2 provides a mock function with given fields: _a0, _a1
func (_m *MockIndexNode) CreateJobV2(_a0 context.Context, _a1 *indexpb.CreateJobV2Request) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobV2Request) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobV2Request) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *indexpb.CreateJobV2Request) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockIndexNode_CreateJobV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateJobV2'
type MockIndexNode_CreateJobV2_Call struct {
*mock.Call
}
// CreateJobV2 is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *indexpb.CreateJobV2Request
func (_e *MockIndexNode_Expecter) CreateJobV2(_a0 interface{}, _a1 interface{}) *MockIndexNode_CreateJobV2_Call {
return &MockIndexNode_CreateJobV2_Call{Call: _e.mock.On("CreateJobV2", _a0, _a1)}
}
func (_c *MockIndexNode_CreateJobV2_Call) Run(run func(_a0 context.Context, _a1 *indexpb.CreateJobV2Request)) *MockIndexNode_CreateJobV2_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*indexpb.CreateJobV2Request))
})
return _c
}
func (_c *MockIndexNode_CreateJobV2_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNode_CreateJobV2_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockIndexNode_CreateJobV2_Call) RunAndReturn(run func(context.Context, *indexpb.CreateJobV2Request) (*commonpb.Status, error)) *MockIndexNode_CreateJobV2_Call {
_c.Call.Return(run)
return _c
}
// DropJobs provides a mock function with given fields: _a0, _a1
func (_m *MockIndexNode) DropJobs(_a0 context.Context, _a1 *indexpb.DropJobsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
@ -140,6 +195,61 @@ func (_c *MockIndexNode_DropJobs_Call) RunAndReturn(run func(context.Context, *i
return _c
}
// DropJobsV2 provides a mock function with given fields: _a0, _a1
func (_m *MockIndexNode) DropJobsV2(_a0 context.Context, _a1 *indexpb.DropJobsV2Request) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsV2Request) (*commonpb.Status, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsV2Request) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *indexpb.DropJobsV2Request) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockIndexNode_DropJobsV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropJobsV2'
type MockIndexNode_DropJobsV2_Call struct {
*mock.Call
}
// DropJobsV2 is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *indexpb.DropJobsV2Request
func (_e *MockIndexNode_Expecter) DropJobsV2(_a0 interface{}, _a1 interface{}) *MockIndexNode_DropJobsV2_Call {
return &MockIndexNode_DropJobsV2_Call{Call: _e.mock.On("DropJobsV2", _a0, _a1)}
}
func (_c *MockIndexNode_DropJobsV2_Call) Run(run func(_a0 context.Context, _a1 *indexpb.DropJobsV2Request)) *MockIndexNode_DropJobsV2_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*indexpb.DropJobsV2Request))
})
return _c
}
func (_c *MockIndexNode_DropJobsV2_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNode_DropJobsV2_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockIndexNode_DropJobsV2_Call) RunAndReturn(run func(context.Context, *indexpb.DropJobsV2Request) (*commonpb.Status, error)) *MockIndexNode_DropJobsV2_Call {
_c.Call.Return(run)
return _c
}
// GetAddress provides a mock function with given fields:
func (_m *MockIndexNode) GetAddress() string {
ret := _m.Called()
@ -497,6 +607,61 @@ func (_c *MockIndexNode_QueryJobs_Call) RunAndReturn(run func(context.Context, *
return _c
}
// QueryJobsV2 provides a mock function with given fields: _a0, _a1
func (_m *MockIndexNode) QueryJobsV2(_a0 context.Context, _a1 *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error) {
ret := _m.Called(_a0, _a1)
var r0 *indexpb.QueryJobsV2Response
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error)); ok {
return rf(_a0, _a1)
}
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsV2Request) *indexpb.QueryJobsV2Response); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*indexpb.QueryJobsV2Response)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *indexpb.QueryJobsV2Request) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockIndexNode_QueryJobsV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryJobsV2'
type MockIndexNode_QueryJobsV2_Call struct {
*mock.Call
}
// QueryJobsV2 is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *indexpb.QueryJobsV2Request
func (_e *MockIndexNode_Expecter) QueryJobsV2(_a0 interface{}, _a1 interface{}) *MockIndexNode_QueryJobsV2_Call {
return &MockIndexNode_QueryJobsV2_Call{Call: _e.mock.On("QueryJobsV2", _a0, _a1)}
}
func (_c *MockIndexNode_QueryJobsV2_Call) Run(run func(_a0 context.Context, _a1 *indexpb.QueryJobsV2Request)) *MockIndexNode_QueryJobsV2_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*indexpb.QueryJobsV2Request))
})
return _c
}
func (_c *MockIndexNode_QueryJobsV2_Call) Return(_a0 *indexpb.QueryJobsV2Response, _a1 error) *MockIndexNode_QueryJobsV2_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockIndexNode_QueryJobsV2_Call) RunAndReturn(run func(context.Context, *indexpb.QueryJobsV2Request) (*indexpb.QueryJobsV2Response, error)) *MockIndexNode_QueryJobsV2_Call {
_c.Call.Return(run)
return _c
}
// Register provides a mock function with given fields:
func (_m *MockIndexNode) Register() error {
ret := _m.Called()

View File

@ -142,6 +142,76 @@ func (_c *MockIndexNodeClient_CreateJob_Call) RunAndReturn(run func(context.Cont
return _c
}
// CreateJobV2 provides a mock function with given fields: ctx, in, opts
func (_m *MockIndexNodeClient) CreateJobV2(ctx context.Context, in *indexpb.CreateJobV2Request, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobV2Request, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobV2Request, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *indexpb.CreateJobV2Request, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockIndexNodeClient_CreateJobV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateJobV2'
type MockIndexNodeClient_CreateJobV2_Call struct {
*mock.Call
}
// CreateJobV2 is a helper method to define mock.On call
// - ctx context.Context
// - in *indexpb.CreateJobV2Request
// - opts ...grpc.CallOption
func (_e *MockIndexNodeClient_Expecter) CreateJobV2(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_CreateJobV2_Call {
return &MockIndexNodeClient_CreateJobV2_Call{Call: _e.mock.On("CreateJobV2",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockIndexNodeClient_CreateJobV2_Call) Run(run func(ctx context.Context, in *indexpb.CreateJobV2Request, opts ...grpc.CallOption)) *MockIndexNodeClient_CreateJobV2_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*indexpb.CreateJobV2Request), variadicArgs...)
})
return _c
}
func (_c *MockIndexNodeClient_CreateJobV2_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNodeClient_CreateJobV2_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockIndexNodeClient_CreateJobV2_Call) RunAndReturn(run func(context.Context, *indexpb.CreateJobV2Request, ...grpc.CallOption) (*commonpb.Status, error)) *MockIndexNodeClient_CreateJobV2_Call {
_c.Call.Return(run)
return _c
}
// DropJobs provides a mock function with given fields: ctx, in, opts
func (_m *MockIndexNodeClient) DropJobs(ctx context.Context, in *indexpb.DropJobsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
@ -212,6 +282,76 @@ func (_c *MockIndexNodeClient_DropJobs_Call) RunAndReturn(run func(context.Conte
return _c
}
// DropJobsV2 provides a mock function with given fields: ctx, in, opts
func (_m *MockIndexNodeClient) DropJobsV2(ctx context.Context, in *indexpb.DropJobsV2Request, opts ...grpc.CallOption) (*commonpb.Status, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *commonpb.Status
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsV2Request, ...grpc.CallOption) (*commonpb.Status, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsV2Request, ...grpc.CallOption) *commonpb.Status); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *indexpb.DropJobsV2Request, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockIndexNodeClient_DropJobsV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropJobsV2'
type MockIndexNodeClient_DropJobsV2_Call struct {
*mock.Call
}
// DropJobsV2 is a helper method to define mock.On call
// - ctx context.Context
// - in *indexpb.DropJobsV2Request
// - opts ...grpc.CallOption
func (_e *MockIndexNodeClient_Expecter) DropJobsV2(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_DropJobsV2_Call {
return &MockIndexNodeClient_DropJobsV2_Call{Call: _e.mock.On("DropJobsV2",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockIndexNodeClient_DropJobsV2_Call) Run(run func(ctx context.Context, in *indexpb.DropJobsV2Request, opts ...grpc.CallOption)) *MockIndexNodeClient_DropJobsV2_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*indexpb.DropJobsV2Request), variadicArgs...)
})
return _c
}
func (_c *MockIndexNodeClient_DropJobsV2_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNodeClient_DropJobsV2_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockIndexNodeClient_DropJobsV2_Call) RunAndReturn(run func(context.Context, *indexpb.DropJobsV2Request, ...grpc.CallOption) (*commonpb.Status, error)) *MockIndexNodeClient_DropJobsV2_Call {
_c.Call.Return(run)
return _c
}
// GetComponentStates provides a mock function with given fields: ctx, in, opts
func (_m *MockIndexNodeClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) {
_va := make([]interface{}, len(opts))
@ -562,6 +702,76 @@ func (_c *MockIndexNodeClient_QueryJobs_Call) RunAndReturn(run func(context.Cont
return _c
}
// QueryJobsV2 provides a mock function with given fields: ctx, in, opts
func (_m *MockIndexNodeClient) QueryJobsV2(ctx context.Context, in *indexpb.QueryJobsV2Request, opts ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) {
_va := make([]interface{}, len(opts))
for _i := range opts {
_va[_i] = opts[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, in)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *indexpb.QueryJobsV2Response
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsV2Request, ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error)); ok {
return rf(ctx, in, opts...)
}
if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsV2Request, ...grpc.CallOption) *indexpb.QueryJobsV2Response); ok {
r0 = rf(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*indexpb.QueryJobsV2Response)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *indexpb.QueryJobsV2Request, ...grpc.CallOption) error); ok {
r1 = rf(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockIndexNodeClient_QueryJobsV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryJobsV2'
type MockIndexNodeClient_QueryJobsV2_Call struct {
*mock.Call
}
// QueryJobsV2 is a helper method to define mock.On call
// - ctx context.Context
// - in *indexpb.QueryJobsV2Request
// - opts ...grpc.CallOption
func (_e *MockIndexNodeClient_Expecter) QueryJobsV2(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_QueryJobsV2_Call {
return &MockIndexNodeClient_QueryJobsV2_Call{Call: _e.mock.On("QueryJobsV2",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *MockIndexNodeClient_QueryJobsV2_Call) Run(run func(ctx context.Context, in *indexpb.QueryJobsV2Request, opts ...grpc.CallOption)) *MockIndexNodeClient_QueryJobsV2_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]grpc.CallOption, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(grpc.CallOption)
}
}
run(args[0].(context.Context), args[1].(*indexpb.QueryJobsV2Request), variadicArgs...)
})
return _c
}
func (_c *MockIndexNodeClient_QueryJobsV2_Call) Return(_a0 *indexpb.QueryJobsV2Response, _a1 error) *MockIndexNodeClient_QueryJobsV2_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockIndexNodeClient_QueryJobsV2_Call) RunAndReturn(run func(context.Context, *indexpb.QueryJobsV2Request, ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error)) *MockIndexNodeClient_QueryJobsV2_Call {
_c.Call.Return(run)
return _c
}
// ShowConfigurations provides a mock function with given fields: ctx, in, opts
func (_m *MockIndexNodeClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) {
_va := make([]interface{}, len(opts))

View File

@ -0,0 +1,55 @@
syntax = "proto3";
package milvus.proto.clustering;
option go_package = "github.com/milvus-io/milvus/internal/proto/clusteringpb";
import "schema.proto";
// Keep this StorageConfig in sync with the copies in index_coord.proto and index_cgo_msg.proto
message StorageConfig {
string address = 1;
string access_keyID = 2;
string secret_access_key = 3;
bool useSSL = 4;
string bucket_name = 5;
string root_path = 6;
bool useIAM = 7;
string IAMEndpoint = 8;
string storage_type = 9;
bool use_virtual_host = 10;
string region = 11;
string cloud_provider = 12;
int64 request_timeout_ms = 13;
string sslCACert = 14;
}
message InsertFiles {
repeated string insert_files = 1;
}
message AnalyzeInfo {
string clusterID = 1;
int64 buildID = 2;
int64 collectionID = 3;
int64 partitionID = 4;
int64 segmentID = 5;
int64 version = 6;
int64 dim = 7;
int64 num_clusters = 8;
int64 train_size = 9;
double min_cluster_ratio = 10; // if min_cluster_size / avg_cluster_size < min_cluster_ratio, the clustering is considered skewed
double max_cluster_ratio = 11; // if max_cluster_size / avg_cluster_size > max_cluster_ratio, the clustering is considered skewed
int64 max_cluster_size = 12;
map<int64, InsertFiles> insert_files = 13;
map<int64, int64> num_rows = 14;
schema.FieldSchema field_schema = 15;
StorageConfig storage_config = 16;
}
message ClusteringCentroidsStats {
repeated schema.VectorField centroids = 1;
}
message ClusteringCentroidIdMappingStats {
repeated uint32 centroid_id_mapping = 1;
repeated int64 num_in_centroid = 2;
}
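A hedged sketch of consuming the two stats messages defined above, assuming the blobs were read from the analyze output files and the usual proto and clusteringpb imports are in scope.

func decodeCentroidStats(centroidBlob, mappingBlob []byte) (*clusteringpb.ClusteringCentroidsStats, *clusteringpb.ClusteringCentroidIdMappingStats, error) {
    centroids := &clusteringpb.ClusteringCentroidsStats{}
    if err := proto.Unmarshal(centroidBlob, centroids); err != nil {
        return nil, nil, err
    }
    mapping := &clusteringpb.ClusteringCentroidIdMappingStats{}
    if err := proto.Unmarshal(mappingBlob, mapping); err != nil {
        return nil, nil, err
    }
    // centroid_id_mapping assigns each row offset to a centroid;
    // num_in_centroid counts how many rows fell into each centroid.
    return centroids, mapping, nil
}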

View File

@ -10,19 +10,19 @@ import "milvus.proto";
import "schema.proto";
service IndexCoord {
rpc GetComponentStates(milvus.GetComponentStatesRequest) returns (milvus.ComponentStates) {}
rpc GetStatisticsChannel(internal.GetStatisticsChannelRequest) returns(milvus.StringResponse){}
rpc CreateIndex(CreateIndexRequest) returns (common.Status){}
rpc AlterIndex(AlterIndexRequest) returns (common.Status){}
// Deprecated: use DescribeIndex instead
rpc GetIndexState(GetIndexStateRequest) returns (GetIndexStateResponse) {}
rpc GetSegmentIndexState(GetSegmentIndexStateRequest) returns (GetSegmentIndexStateResponse) {}
rpc GetIndexInfos(GetIndexInfoRequest) returns (GetIndexInfoResponse){}
rpc DropIndex(DropIndexRequest) returns (common.Status) {}
rpc DescribeIndex(DescribeIndexRequest) returns (DescribeIndexResponse) {}
rpc GetIndexStatistics(GetIndexStatisticsRequest) returns (GetIndexStatisticsResponse) {}
// Deprecated: use DescribeIndex instead
rpc GetIndexBuildProgress(GetIndexBuildProgressRequest) returns (GetIndexBuildProgressResponse) {}
rpc GetComponentStates(milvus.GetComponentStatesRequest) returns (milvus.ComponentStates) {}
rpc GetStatisticsChannel(internal.GetStatisticsChannelRequest) returns(milvus.StringResponse){}
rpc CreateIndex(CreateIndexRequest) returns (common.Status){}
rpc AlterIndex(AlterIndexRequest) returns (common.Status){}
// Deprecated: use DescribeIndex instead
rpc GetIndexState(GetIndexStateRequest) returns (GetIndexStateResponse) {}
rpc GetSegmentIndexState(GetSegmentIndexStateRequest) returns (GetSegmentIndexStateResponse) {}
rpc GetIndexInfos(GetIndexInfoRequest) returns (GetIndexInfoResponse){}
rpc DropIndex(DropIndexRequest) returns (common.Status) {}
rpc DescribeIndex(DescribeIndexRequest) returns (DescribeIndexResponse) {}
rpc GetIndexStatistics(GetIndexStatisticsRequest) returns (GetIndexStatisticsResponse) {}
// Deprecated: use DescribeIndex instead
rpc GetIndexBuildProgress(GetIndexBuildProgressRequest) returns (GetIndexBuildProgressResponse) {}
rpc ShowConfigurations(internal.ShowConfigurationsRequest)
returns (internal.ShowConfigurationsResponse) {
@ -60,6 +60,13 @@ service IndexNode {
rpc GetMetrics(milvus.GetMetricsRequest)
returns (milvus.GetMetricsResponse) {
}
rpc CreateJobV2(CreateJobV2Request) returns (common.Status) {
}
rpc QueryJobsV2(QueryJobsV2Request) returns (QueryJobsV2Response) {
}
rpc DropJobsV2(DropJobsV2Request) returns (common.Status) {
}
}
message IndexInfo {
@ -159,9 +166,9 @@ message CreateIndexRequest {
}
message AlterIndexRequest {
int64 collectionID = 1;
string index_name = 2;
repeated common.KeyValuePair params = 3;
int64 collectionID = 1;
string index_name = 2;
repeated common.KeyValuePair params = 3;
}
message GetIndexInfoRequest {
@ -226,7 +233,7 @@ message GetIndexBuildProgressResponse {
int64 pending_index_rows = 4;
}
// Synchronously modify StorageConfig in index_cgo_msg.proto file
// Keep this StorageConfig in sync with the copies in index_cgo_msg.proto and clustering.proto
message StorageConfig {
string address = 1;
string access_keyID = 2;
@ -246,11 +253,11 @@ message StorageConfig {
// Synchronously modify OptionalFieldInfo in index_cgo_msg.proto file
message OptionalFieldInfo {
int64 fieldID = 1;
string field_name = 2;
int32 field_type = 3;
repeated string data_paths = 4;
repeated int64 data_ids = 5;
int64 fieldID = 1;
string field_name = 2;
int32 field_type = 3;
repeated string data_paths = 4;
repeated int64 data_ids = 5;
}
message CreateJobRequest {
@ -347,3 +354,108 @@ message ListIndexesResponse {
common.Status status = 1;
repeated IndexInfo index_infos = 2;
}
message AnalyzeTask {
int64 collectionID = 1;
int64 partitionID = 2;
int64 fieldID = 3;
string field_name = 4;
schema.DataType field_type = 5;
int64 taskID = 6;
int64 version = 7;
repeated int64 segmentIDs = 8;
int64 nodeID = 9;
JobState state = 10;
string fail_reason = 11;
int64 dim = 12;
string centroids_file = 13;
}
message SegmentStats {
int64 ID = 1;
int64 num_rows = 2;
repeated int64 logIDs = 3;
}
message AnalyzeRequest {
string clusterID = 1;
int64 taskID = 2;
int64 collectionID = 3;
int64 partitionID = 4;
int64 fieldID = 5;
string fieldName = 6;
schema.DataType field_type = 7;
map<int64, SegmentStats> segment_stats = 8;
int64 version = 9;
StorageConfig storage_config = 10;
int64 dim = 11;
double max_train_size_ratio = 12;
int64 num_clusters = 13;
schema.FieldSchema field = 14;
double min_cluster_size_ratio = 15;
double max_cluster_size_ratio = 16;
int64 max_cluster_size = 17;
}
message AnalyzeResult {
int64 taskID = 1;
JobState state = 2;
string fail_reason = 3;
string centroids_file = 4;
}
enum JobType {
JobTypeNone = 0;
JobTypeIndexJob = 1;
JobTypeAnalyzeJob = 2;
}
message CreateJobV2Request {
string clusterID = 1;
int64 taskID = 2;
JobType job_type = 3;
oneof request {
AnalyzeRequest analyze_request = 4;
CreateJobRequest index_request = 5;
}
// JobDescriptor job = 3;
}
message QueryJobsV2Request {
string clusterID = 1;
repeated int64 taskIDs = 2;
JobType job_type = 3;
}
message IndexJobResults {
repeated IndexTaskInfo results = 1;
}
message AnalyzeResults {
repeated AnalyzeResult results = 1;
}
message QueryJobsV2Response {
common.Status status = 1;
string clusterID = 2;
oneof result {
IndexJobResults index_job_results = 3;
AnalyzeResults analyze_job_results = 4;
}
}
message DropJobsV2Request {
string clusterID = 1;
repeated int64 taskIDs = 2;
JobType job_type = 3;
}
enum JobState {
JobStateNone = 0;
JobStateInit = 1;
JobStateInProgress = 2;
JobStateFinished = 3;
JobStateFailed = 4;
JobStateRetry = 5;
}
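A hedged sketch of wrapping an AnalyzeRequest into the new CreateJobV2 RPC. The oneof wrapper type follows the usual protoc-gen-go naming convention, and the client interface name, IDs, and parameter values are assumptions.

func submitAnalyzeJob(ctx context.Context, client indexpb.IndexNodeClient) error {
    req := &indexpb.CreateJobV2Request{
        ClusterID: "by-dev",
        TaskID:    1004,
        JobType:   indexpb.JobType_JobTypeAnalyzeJob,
        Request: &indexpb.CreateJobV2Request_AnalyzeRequest{
            AnalyzeRequest: &indexpb.AnalyzeRequest{
                ClusterID:    "by-dev",
                TaskID:       1004,
                CollectionID: 1000,
                PartitionID:  1001,
                FieldID:      102,
                FieldType:    schemapb.DataType_FloatVector,
                NumClusters:  16,
                Dim:          128,
            },
        },
    }
    // The matching result is fetched later with QueryJobsV2 using the same taskID and job type.
    _, err := client.CreateJobV2(ctx, req)
    return err
}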

View File

@ -0,0 +1,116 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package analyzecgowrapper
/*
#cgo pkg-config: milvus_clustering
#include <stdlib.h> // free
#include "clustering/analyze_c.h"
*/
import "C"
import (
"context"
"runtime"
"unsafe"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/clusteringpb"
"github.com/milvus-io/milvus/pkg/log"
)
type CodecAnalyze interface {
Delete() error
GetResult(size int) (string, int64, []string, []int64, error)
}
func Analyze(ctx context.Context, analyzeInfo *clusteringpb.AnalyzeInfo) (CodecAnalyze, error) {
analyzeInfoBlob, err := proto.Marshal(analyzeInfo)
if err != nil {
log.Ctx(ctx).Warn("marshal analyzeInfo failed",
zap.Int64("buildID", analyzeInfo.GetBuildID()),
zap.Error(err))
return nil, err
}
var analyzePtr C.CAnalyze
status := C.Analyze(&analyzePtr, (*C.uint8_t)(unsafe.Pointer(&analyzeInfoBlob[0])), (C.uint64_t)(len(analyzeInfoBlob)))
if err := HandleCStatus(&status, "failed to analyze task"); err != nil {
return nil, err
}
analyze := &CgoAnalyze{
analyzePtr: analyzePtr,
close: false,
}
runtime.SetFinalizer(analyze, func(ca *CgoAnalyze) {
if ca != nil && !ca.close {
log.Error("there is leakage in analyze object, please check.")
}
})
return analyze, nil
}
type CgoAnalyze struct {
analyzePtr C.CAnalyze
close bool
}
func (ca *CgoAnalyze) Delete() error {
if ca.close {
return nil
}
var status C.CStatus
if ca.analyzePtr != nil {
status = C.DeleteAnalyze(ca.analyzePtr)
}
ca.close = true
return HandleCStatus(&status, "failed to delete analyze")
}
func (ca *CgoAnalyze) GetResult(size int) (string, int64, []string, []int64, error) {
cOffsetMappingFilesPath := make([]unsafe.Pointer, size)
cOffsetMappingFilesSize := make([]C.int64_t, size)
cCentroidsFilePath := C.CString("")
cCentroidsFileSize := C.int64_t(0)
defer C.free(unsafe.Pointer(cCentroidsFilePath))
status := C.GetAnalyzeResultMeta(ca.analyzePtr,
&cCentroidsFilePath,
&cCentroidsFileSize,
unsafe.Pointer(&cOffsetMappingFilesPath[0]),
&cOffsetMappingFilesSize[0],
)
if err := HandleCStatus(&status, "failed to get analyze result"); err != nil {
return "", 0, nil, nil, err
}
offsetMappingFilesPath := make([]string, size)
offsetMappingFilesSize := make([]int64, size)
centroidsFilePath := C.GoString(cCentroidsFilePath)
centroidsFileSize := int64(cCentroidsFileSize)
for i := 0; i < size; i++ {
offsetMappingFilesPath[i] = C.GoString((*C.char)(cOffsetMappingFilesPath[i]))
offsetMappingFilesSize[i] = int64(cOffsetMappingFilesSize[i])
}
return centroidsFilePath, centroidsFileSize, offsetMappingFilesPath, offsetMappingFilesSize, nil
}
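A hedged usage sketch for the wrapper above; the AnalyzeInfo contents and the segment count come from the caller, and error handling is abbreviated.

func runAnalyze(ctx context.Context, info *clusteringpb.AnalyzeInfo, numSegments int) (string, error) {
    analyze, err := Analyze(ctx, info)
    if err != nil {
        return "", err
    }
    // Release the C-side object explicitly so the finalizer never reports a leak.
    defer analyze.Delete()

    centroidsFile, _, offsetFiles, offsetSizes, err := analyze.GetResult(numSegments)
    if err != nil {
        return "", err
    }
    log.Ctx(ctx).Info("analyze finished",
        zap.String("centroidsFile", centroidsFile),
        zap.Strings("offsetMappingFiles", offsetFiles),
        zap.Int64s("offsetMappingSizes", offsetSizes))
    return centroidsFile, nil
}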

View File

@ -0,0 +1,55 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package analyzecgowrapper
/*
#cgo pkg-config: milvus_common
#include <stdlib.h> // free
#include "common/type_c.h"
*/
import "C"
import (
"fmt"
"unsafe"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
)
// HandleCStatus deal with the error returned from CGO
func HandleCStatus(status *C.CStatus, extraInfo string) error {
if status.error_code == 0 {
return nil
}
errorCode := int(status.error_code)
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
logMsg := fmt.Sprintf("%s, C Runtime Exception: %s\n", extraInfo, errorMsg)
log.Warn(logMsg)
if errorCode == 2003 {
return merr.WrapErrSegcoreUnsupported(int32(errorCode), logMsg)
}
if errorCode == 2033 {
log.Info("fake finished the task")
return merr.ErrSegcorePretendFinished
}
return merr.WrapErrSegcore(int32(errorCode), logMsg)
}

View File

@ -69,6 +69,18 @@ func (m *GrpcIndexNodeClient) ShowConfigurations(ctx context.Context, in *intern
return &internalpb.ShowConfigurationsResponse{}, m.Err
}
func (m *GrpcIndexNodeClient) CreateJobV2(ctx context.Context, in *indexpb.CreateJobV2Request, opt ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcIndexNodeClient) QueryJobsV2(ctx context.Context, in *indexpb.QueryJobsV2Request, opt ...grpc.CallOption) (*indexpb.QueryJobsV2Response, error) {
return &indexpb.QueryJobsV2Response{}, m.Err
}
func (m *GrpcIndexNodeClient) DropJobsV2(ctx context.Context, in *indexpb.DropJobsV2Request, opt ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcIndexNodeClient) Close() error {
return m.Err
}

View File

@ -93,6 +93,11 @@ const (
// PartitionStatsPath storage path const for partition stats files
PartitionStatsPath = `part_stats`
// AnalyzeStatsPath storage path const for analyze.
AnalyzeStatsPath = `analyze_stats`
OffsetMapping = `offset_mapping`
Centroids = "centroids"
)
// Search, Index parameter keys

View File

@ -248,6 +248,9 @@ type commonConfig struct {
BloomFilterType ParamItem `refreshable:"true"`
MaxBloomFalsePositive ParamItem `refreshable:"true"`
PanicWhenPluginFail ParamItem `refreshable:"false"`
UsePartitionKeyAsClusteringKey ParamItem `refreshable:"true"`
UseVectorAsClusteringKey ParamItem `refreshable:"true"`
}
func (p *commonConfig) init(base *BaseTable) {
@ -761,6 +764,22 @@ like the old password verification when updating the credential`,
Doc: "panic or not when plugin fail to init",
}
p.PanicWhenPluginFail.Init(base.mgr)
p.UsePartitionKeyAsClusteringKey = ParamItem{
Key: "common.usePartitionKeyAsClusteringKey",
Version: "2.4.2",
Doc: "if true, do clustering compaction and segment pruning on the partition key field",
DefaultValue: "false",
}
p.UsePartitionKeyAsClusteringKey.Init(base.mgr)
p.UseVectorAsClusteringKey = ParamItem{
Key: "common.useVectorAsClusteringKey",
Version: "2.4.2",
Doc: "if true, do clustering compaction and segment pruning on the vector field",
DefaultValue: "false",
}
p.UseVectorAsClusteringKey.Init(base.mgr)
}
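A hedged sketch of reading the two new switches; access through the global paramtable singleton and its CommonCfg field is assumed.

func clusteringKeySwitches() (usePartitionKey bool, useVector bool) {
    params := paramtable.Get()
    // Each switch is documented above; both default to false.
    return params.CommonCfg.UsePartitionKeyAsClusteringKey.GetAsBool(),
        params.CommonCfg.UseVectorAsClusteringKey.GetAsBool()
}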
type gpuConfig struct {
@ -2702,7 +2721,7 @@ user-task-polling:
p.DefaultSegmentFilterRatio = ParamItem{
Key: "queryNode.defaultSegmentFilterRatio",
Version: "2.4.0",
DefaultValue: "0.5",
DefaultValue: "2",
Doc: "filter ratio used for pruning segments when searching",
}
p.DefaultSegmentFilterRatio.Init(base.mgr)
@ -2772,6 +2791,26 @@ type dataCoordConfig struct {
ChannelCheckpointMaxLag ParamItem `refreshable:"true"`
SyncSegmentsInterval ParamItem `refreshable:"false"`
// Clustering Compaction
ClusteringCompactionEnable ParamItem `refreshable:"true"`
ClusteringCompactionAutoEnable ParamItem `refreshable:"true"`
ClusteringCompactionTriggerInterval ParamItem `refreshable:"true"`
ClusteringCompactionStateCheckInterval ParamItem `refreshable:"true"`
ClusteringCompactionGCInterval ParamItem `refreshable:"true"`
ClusteringCompactionMinInterval ParamItem `refreshable:"true"`
ClusteringCompactionMaxInterval ParamItem `refreshable:"true"`
ClusteringCompactionNewDataSizeThreshold ParamItem `refreshable:"true"`
ClusteringCompactionDropTolerance ParamItem `refreshable:"true"`
ClusteringCompactionPreferSegmentSize ParamItem `refreshable:"true"`
ClusteringCompactionMaxSegmentSize ParamItem `refreshable:"true"`
ClusteringCompactionMaxTrainSizeRatio ParamItem `refreshable:"true"`
ClusteringCompactionTimeoutInSeconds ParamItem `refreshable:"true"`
ClusteringCompactionMaxCentroidsNum ParamItem `refreshable:"true"`
ClusteringCompactionMinCentroidsNum ParamItem `refreshable:"true"`
ClusteringCompactionMinClusterSizeRatio ParamItem `refreshable:"true"`
ClusteringCompactionMaxClusterSizeRatio ParamItem `refreshable:"true"`
ClusteringCompactionMaxClusterSize ParamItem `refreshable:"true"`
// LevelZero Segment
EnableLevelZeroSegment ParamItem `refreshable:"false"`
LevelZeroCompactionTriggerMinSize ParamItem `refreshable:"true"`
@ -3172,6 +3211,156 @@ During compaction, the size of segment # of rows is able to exceed segment max #
}
p.LevelZeroCompactionTriggerDeltalogMaxNum.Init(base.mgr)
p.ClusteringCompactionEnable = ParamItem{
Key: "dataCoord.compaction.clustering.enable",
Version: "2.4.2",
DefaultValue: "false",
Doc: "Enable clustering compaction",
Export: true,
}
p.ClusteringCompactionEnable.Init(base.mgr)
p.ClusteringCompactionAutoEnable = ParamItem{
Key: "dataCoord.compaction.clustering.autoEnable",
Version: "2.4.2",
DefaultValue: "false",
Doc: "Enable auto clustering compaction",
Export: true,
}
p.ClusteringCompactionAutoEnable.Init(base.mgr)
p.ClusteringCompactionTriggerInterval = ParamItem{
Key: "dataCoord.compaction.clustering.triggerInterval",
Version: "2.4.2",
DefaultValue: "600",
}
p.ClusteringCompactionTriggerInterval.Init(base.mgr)
p.ClusteringCompactionStateCheckInterval = ParamItem{
Key: "dataCoord.compaction.clustering.stateCheckInterval",
Version: "2.4.2",
DefaultValue: "10",
}
p.ClusteringCompactionStateCheckInterval.Init(base.mgr)
p.ClusteringCompactionGCInterval = ParamItem{
Key: "dataCoord.compaction.clustering.gcInterval",
Version: "2.4.2",
DefaultValue: "600",
}
p.ClusteringCompactionGCInterval.Init(base.mgr)
p.ClusteringCompactionMinInterval = ParamItem{
Key: "dataCoord.compaction.clustering.minInterval",
Version: "2.4.2",
Doc: "The minimum interval between clustering compaction runs on one collection, to avoid redundant compactions",
DefaultValue: "3600",
}
p.ClusteringCompactionMinInterval.Init(base.mgr)
p.ClusteringCompactionMaxInterval = ParamItem{
Key: "dataCoord.compaction.clustering.maxInterval",
Version: "2.4.2",
Doc: "If a collection has not been clustering compacted for longer than maxInterval, force a compaction",
DefaultValue: "86400",
}
p.ClusteringCompactionMaxInterval.Init(base.mgr)
p.ClusteringCompactionNewDataSizeThreshold = ParamItem{
Key: "dataCoord.compaction.clustering.newDataSizeThreshold",
Version: "2.4.2",
Doc: "If the new data size is larger than newDataSizeThreshold, execute clustering compaction",
DefaultValue: "512m",
}
p.ClusteringCompactionNewDataSizeThreshold.Init(base.mgr)
p.ClusteringCompactionTimeoutInSeconds = ParamItem{
Key: "dataCoord.compaction.clustering.timeout",
Version: "2.4.2",
DefaultValue: "3600",
}
p.ClusteringCompactionTimeoutInSeconds.Init(base.mgr)
p.ClusteringCompactionDropTolerance = ParamItem{
Key: "dataCoord.compaction.clustering.dropTolerance",
Version: "2.4.2",
Doc: "If a clustering compaction job has been finished for longer than dropTolerance, garbage collect it",
DefaultValue: "259200",
}
p.ClusteringCompactionDropTolerance.Init(base.mgr)
p.ClusteringCompactionPreferSegmentSize = ParamItem{
Key: "dataCoord.compaction.clustering.preferSegmentSize",
Version: "2.4.2",
DefaultValue: "512m",
PanicIfEmpty: false,
Export: true,
}
p.ClusteringCompactionPreferSegmentSize.Init(base.mgr)
p.ClusteringCompactionMaxSegmentSize = ParamItem{
Key: "dataCoord.compaction.clustering.maxSegmentSize",
Version: "2.4.2",
DefaultValue: "1024m",
PanicIfEmpty: false,
Export: true,
}
p.ClusteringCompactionMaxSegmentSize.Init(base.mgr)
p.ClusteringCompactionMaxTrainSizeRatio = ParamItem{
Key: "dataCoord.compaction.clustering.maxTrainSizeRatio",
Version: "2.4.2",
DefaultValue: "0.8",
Doc: "max data size ratio in Kmeans training; if the data is larger, it will be down-sampled to meet this limit",
Export: true,
}
p.ClusteringCompactionMaxTrainSizeRatio.Init(base.mgr)
p.ClusteringCompactionMaxCentroidsNum = ParamItem{
Key: "dataCoord.compaction.clustering.maxCentroidsNum",
Version: "2.4.2",
DefaultValue: "10240",
Doc: "maximum number of centroids in Kmeans training",
Export: true,
}
p.ClusteringCompactionMaxCentroidsNum.Init(base.mgr)
p.ClusteringCompactionMinCentroidsNum = ParamItem{
Key: "dataCoord.compaction.clustering.minCentroidsNum",
Version: "2.4.2",
DefaultValue: "16",
Doc: "minimum number of centroids in Kmeans training",
Export: true,
}
p.ClusteringCompactionMinCentroidsNum.Init(base.mgr)
p.ClusteringCompactionMinClusterSizeRatio = ParamItem{
Key: "dataCoord.compaction.clustering.minClusterSizeRatio",
Version: "2.4.2",
DefaultValue: "0.01",
Doc: "minimum ratio of cluster size to average cluster size in Kmeans training",
Export: true,
}
p.ClusteringCompactionMinClusterSizeRatio.Init(base.mgr)
p.ClusteringCompactionMaxClusterSizeRatio = ParamItem{
Key: "dataCoord.compaction.clustering.maxClusterSizeRatio",
Version: "2.4.2",
DefaultValue: "10",
Doc: "maximum ratio of cluster size to average cluster size in Kmeans training",
Export: true,
}
p.ClusteringCompactionMaxClusterSizeRatio.Init(base.mgr)
p.ClusteringCompactionMaxClusterSize = ParamItem{
Key: "dataCoord.compaction.clustering.maxClusterSize",
Version: "2.4.2",
DefaultValue: "5g",
Doc: "maximum cluster size in Kmeans training",
Export: true,
}
p.ClusteringCompactionMaxClusterSize.Init(base.mgr)
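// A hedged sketch (the trigger logic is an assumption, not part of this change) of how
// the thresholds above could gate an automatic clustering compaction:
//
//	if newDataSize > p.ClusteringCompactionNewDataSizeThreshold.GetAsSize() &&
//		time.Since(lastCompaction) > p.ClusteringCompactionMinInterval.GetAsDuration(time.Second) {
//		// submit a clustering compaction task for this collection
//	}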
p.EnableGarbageCollection = ParamItem{
Key: "dataCoord.enableGarbageCollection",
Version: "2.0.0",
@ -3475,6 +3664,10 @@ type dataNodeConfig struct {
// slot
SlotCap ParamItem `refreshable:"true"`
// clustering compaction
ClusteringCompactionMemoryBufferRatio ParamItem `refreshable:"true"`
ClusteringCompactionWorkerPoolSize ParamItem `refreshable:"true"`
}
func (p *dataNodeConfig) init(base *BaseTable) {
@ -3789,6 +3982,26 @@ if this parameter <= 0, will set it as 10`,
Export: true,
}
p.SlotCap.Init(base.mgr)
p.ClusteringCompactionMemoryBufferRatio = ParamItem{
Key: "datanode.clusteringCompaction.memoryBufferRatio",
Version: "2.4.2",
Doc: "The ratio of memory buffer used by clustering compaction. Data larger than the threshold will be spilled to storage.",
DefaultValue: "0.1",
PanicIfEmpty: false,
Export: true,
}
p.ClusteringCompactionMemoryBufferRatio.Init(base.mgr)
p.ClusteringCompactionWorkerPoolSize = ParamItem{
Key: "datanode.clusteringCompaction.cpu",
Version: "2.4.2",
Doc: "Worker pool size for one clustering compaction job.",
DefaultValue: "1",
PanicIfEmpty: false,
Export: true,
}
p.ClusteringCompactionWorkerPoolSize.Init(base.mgr)
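// A hypothetical sketch (the memory source and config path are assumptions, not part of
// this change) of how the buffer ratio could become an absolute limit on a datanode:
//
//	ratio := paramtable.Get().DataNodeCfg.ClusteringCompactionMemoryBufferRatio.GetAsFloat()
//	bufferLimit := int64(float64(hardware.GetMemoryCount()) * ratio) // spill to storage beyond this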
}
// /////////////////////////////////////////////////////////////////////////////

View File

@ -429,6 +429,23 @@ func TestComponentParam(t *testing.T) {
params.Save("datacoord.gracefulStopTimeout", "100")
assert.Equal(t, 100*time.Second, Params.GracefulStopTimeout.GetAsDuration(time.Second))
params.Save("dataCoord.compaction.clustering.enable", "true")
assert.Equal(t, true, Params.ClusteringCompactionEnable.GetAsBool())
params.Save("dataCoord.compaction.clustering.newDataSizeThreshold", "10")
assert.Equal(t, int64(10), Params.ClusteringCompactionNewDataSizeThreshold.GetAsSize())
params.Save("dataCoord.compaction.clustering.newDataSizeThreshold", "10k")
assert.Equal(t, int64(10*1024), Params.ClusteringCompactionNewDataSizeThreshold.GetAsSize())
params.Save("dataCoord.compaction.clustering.newDataSizeThreshold", "10m")
assert.Equal(t, int64(10*1024*1024), Params.ClusteringCompactionNewDataSizeThreshold.GetAsSize())
params.Save("dataCoord.compaction.clustering.newDataSizeThreshold", "10g")
assert.Equal(t, int64(10*1024*1024*1024), Params.ClusteringCompactionNewDataSizeThreshold.GetAsSize())
params.Save("dataCoord.compaction.clustering.dropTolerance", "86400")
assert.Equal(t, int64(86400), Params.ClusteringCompactionDropTolerance.GetAsInt64())
params.Save("dataCoord.compaction.clustering.maxSegmentSize", "100m")
assert.Equal(t, int64(100*1024*1024), Params.ClusteringCompactionMaxSegmentSize.GetAsSize())
params.Save("dataCoord.compaction.clustering.preferSegmentSize", "10m")
assert.Equal(t, int64(10*1024*1024), Params.ClusteringCompactionPreferSegmentSize.GetAsSize())
})
t.Run("test dataNodeConfig", func(t *testing.T) {
@ -482,6 +499,9 @@ func TestComponentParam(t *testing.T) {
params.Save("datanode.gracefulStopTimeout", "100")
assert.Equal(t, 100*time.Second, Params.GracefulStopTimeout.GetAsDuration(time.Second))
assert.Equal(t, 2, Params.SlotCap.GetAsInt())
// clustering compaction
params.Save("datanode.clusteringCompaction.memoryBufferRatio", "0.1")
assert.Equal(t, 0.1, Params.ClusteringCompactionMemoryBufferRatio.GetAsFloat())
})
t.Run("test indexNodeConfig", func(t *testing.T) {
@ -500,6 +520,14 @@ func TestComponentParam(t *testing.T) {
assert.Equal(t, "by-dev-dml1", Params.RootCoordDml.GetValue())
})
t.Run("clustering compaction config", func(t *testing.T) {
Params := &params.CommonCfg
params.Save("common.usePartitionKeyAsClusteringKey", "true")
assert.Equal(t, true, Params.UsePartitionKeyAsClusteringKey.GetAsBool())
params.Save("common.useVectorAsClusteringKey", "true")
assert.Equal(t, true, Params.UseVectorAsClusteringKey.GetAsBool())
})
}
func TestForbiddenItem(t *testing.T) {

View File

@ -378,6 +378,20 @@ func IsDenseFloatVectorType(dataType schemapb.DataType) bool {
}
}
// VectorTypeSize returns the per-dimension size in bytes for a vector data type
func VectorTypeSize(dataType schemapb.DataType) float64 {
switch dataType {
case schemapb.DataType_FloatVector, schemapb.DataType_SparseFloatVector:
return 4.0
case schemapb.DataType_BinaryVector:
return 0.125
case schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
return 2.0
default:
return 0.0
}
}
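// For illustration (assumed usage, not part of this change): the per-row byte size of a
// vector field can be estimated as dim * VectorTypeSize(dataType). For example, a
// 128-dim FloatVector takes about 128 * 4.0 = 512 bytes, while a 128-dim BinaryVector
// takes 128 * 0.125 = 16 bytes.
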
func IsSparseFloatVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_SparseFloatVector
}
@ -1085,6 +1099,16 @@ func IsFieldDataTypeSupportMaterializedView(fieldSchema *schemapb.FieldSchema) b
return fieldSchema.DataType == schemapb.DataType_VarChar || fieldSchema.DataType == schemapb.DataType_String
}
// HasClusterKey checks whether a collection schema has a clustering key field
func HasClusterKey(schema *schemapb.CollectionSchema) bool {
for _, fieldSchema := range schema.Fields {
if fieldSchema.IsClusteringKey {
return true
}
}
return false
}
// GetPrimaryFieldData get primary field data from all field data inserted from sdk
func GetPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schemapb.FieldSchema) (*schemapb.FieldData, error) {
primaryFieldID := primaryFieldSchema.FieldID

View File

@ -1047,6 +1047,40 @@ func TestGetPrimaryFieldSchema(t *testing.T) {
assert.True(t, hasPartitionKey2)
}
func TestGetClusterKeyFieldSchema(t *testing.T) {
int64Field := &schemapb.FieldSchema{
FieldID: 1,
Name: "int64Field",
DataType: schemapb.DataType_Int64,
}
clusterKeyfloatField := &schemapb.FieldSchema{
FieldID: 2,
Name: "floatField",
DataType: schemapb.DataType_Float,
IsClusteringKey: true,
}
unClusterKeyfloatField := &schemapb.FieldSchema{
FieldID: 2,
Name: "floatField",
DataType: schemapb.DataType_Float,
IsClusteringKey: false,
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{int64Field, clusterKeyfloatField},
}
schema2 := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{int64Field, unClusterKeyfloatField},
}
hasClusterKey1 := HasClusterKey(schema)
assert.True(t, hasClusterKey1)
hasClusterKey2 := HasClusterKey(schema2)
assert.False(t, hasClusterKey2)
}
func TestGetPK(t *testing.T) {
type args struct {
data *schemapb.IDs

View File

@ -50,6 +50,7 @@ mkdir -p internalpb
mkdir -p rootcoordpb
mkdir -p segcorepb
mkdir -p clusteringpb
mkdir -p proxypb
mkdir -p indexpb
@ -72,6 +73,7 @@ ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./datapb data_coord.pr
${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./querypb query_coord.proto|| { echo 'generate query_coord.proto failed'; exit 1; }
${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./planpb plan.proto|| { echo 'generate plan.proto failed'; exit 1; }
${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./segcorepb segcore.proto|| { echo 'generate segcore.proto failed'; exit 1; }
${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./clusteringpb clustering.proto|| { echo 'generate clustering.proto failed'; exit 1; }
${protoc_opt} --proto_path=$ROOT_DIR/cmd/tools/migration/legacy/ \
--go_out=plugins=grpc,paths=source_relative:../../cmd/tools/migration/legacy/legacypb legacy.proto || { echo 'generate legacy.proto failed'; exit 1; }
@ -79,10 +81,9 @@ ${protoc_opt} --proto_path=$ROOT_DIR/cmd/tools/migration/legacy/ \
${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb schema.proto|| { echo 'generate schema.proto failed'; exit 1; }
${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb common.proto|| { echo 'generate common.proto failed'; exit 1; }
${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb segcore.proto|| { echo 'generate segcore.proto failed'; exit 1; }
${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb clustering.proto|| { echo 'generate clustering.proto failed'; exit 1; }
${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb index_cgo_msg.proto|| { echo 'generate index_cgo_msg.proto failed'; exit 1; }
${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb cgo_msg.proto|| { echo 'generate cgo_msg.proto failed'; exit 1; }
${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb plan.proto|| { echo 'generate plan.proto failed'; exit 1; }
popd

View File

@ -1,6 +1,8 @@
module github.com/milvus-io/milvus/tests/go_client
go 1.20
go 1.21
toolchain go1.21.10
require (
github.com/milvus-io/milvus/client/v2 v2.0.0-20240521081339-017fd7bc25de