milvus/internal/core/src/index/VectorMemIndex.cpp
yah01 227d2c8b3a
Reduce loading index memory usage (#25698)
Signed-off-by: yah01 <yang.cen@zilliz.com>
2023-07-19 14:02:57 +08:00

296 lines
10 KiB
C++

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "index/VectorMemIndex.h"
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include "fmt/format.h"
#include "index/Meta.h"
#include "index/Utils.h"
#include "exceptions/EasyAssert.h"
#include "config/ConfigKnowhere.h"
#include "knowhere/factory.h"
#include "knowhere/comp/time_recorder.h"
#include "common/BitsetView.h"
#include "common/Slice.h"
#include "common/Consts.h"
#include "common/RangeSearchHelper.h"
#include "common/Utils.h"
#include "storage/FieldData.h"
#include "storage/MemFileManagerImpl.h"
#include "storage/ThreadPool.h"
namespace milvus::index {
/// Creates an in-memory vector index backed by a knowhere index.
///
/// @param index_type   knowhere index type name.
/// @param metric_type  distance metric; must be supported by `index_type`
///                     (asserted via `is_unsupported`).
/// @param file_manager optional file manager. When non-null it is downcast to
///                     `storage::MemFileManagerImpl`; if the downcast fails,
///                     `file_manager_` is left null.
VectorMemIndex::VectorMemIndex(const IndexType& index_type,
                               const MetricType& metric_type,
                               storage::FileManagerImplPtr file_manager)
    : VectorIndex(index_type, metric_type) {
    AssertInfo(!is_unsupported(index_type, metric_type),
               index_type + " doesn't support metric: " + metric_type);

    if (file_manager) {
        file_manager_ =
            std::dynamic_pointer_cast<storage::MemFileManagerImpl>(
                file_manager);
    }

    auto& factory = knowhere::IndexFactory::Instance();
    index_ = factory.Create(GetIndexType());
}
/// Serializes the index, registers the blobs with the file manager, and
/// returns a BinarySet describing the files that were written.
///
/// Entries in the returned set carry no payload (null data). Each one maps a
/// remote path to the file size reported by
/// `file_manager_->GetRemotePathsToFileSize()`.
///
/// NOTE(review): `file_manager_` is dereferenced unconditionally, but the
/// constructor permits a null file manager. Callers must have supplied one.
BinarySet
VectorMemIndex::Upload(const Config& config) {
    auto serialized = Serialize(config);
    file_manager_->AddFile(serialized);

    BinarySet uploaded;
    for (const auto& [remote_path, file_size] :
         file_manager_->GetRemotePathsToFileSize()) {
        uploaded.Append(remote_path, nullptr, file_size);
    }
    return uploaded;
}
/// Serializes the knowhere index into a BinarySet, then runs `Disassemble`
/// on the result before returning it.
///
/// Panics with `ErrorCodeEnum::UnexpectedError` if knowhere reports a
/// serialization failure.
///
/// @param config not read by this implementation.
BinarySet
VectorMemIndex::Serialize(const Config& config) {
    knowhere::BinarySet binary_set;
    const auto status = index_.Serialize(binary_set);
    if (status != knowhere::Status::success) {
        PanicCodeInfo(
            ErrorCodeEnum::UnexpectedError,
            "failed to serialize index, " + KnowhereStatusString(status));
    }
    Disassemble(binary_set);
    return binary_set;
}
/// Deserializes the knowhere index directly from `binary_set`, without
/// reassembling sliced entries first, and records the loaded dimension.
///
/// Panics with `ErrorCodeEnum::UnexpectedError` if knowhere fails to
/// deserialize.
///
/// @param config not read by this implementation.
void
VectorMemIndex::LoadWithoutAssemble(const BinarySet& binary_set,
                                    const Config& config) {
    const auto status = index_.Deserialize(binary_set);
    if (status != knowhere::Status::success) {
        PanicCodeInfo(
            ErrorCodeEnum::UnexpectedError,
            "failed to Deserialize index, " + KnowhereStatusString(status));
    }
    SetDim(index_.Dim());
}
void
VectorMemIndex::Load(const BinarySet& binary_set, const Config& config) {
milvus::Assemble(const_cast<BinarySet&>(binary_set));
LoadWithoutAssemble(binary_set, config);
}
/// Loads the index from the remote files listed under "index_files" in
/// `config`.
///
/// A task on the `ThreadPool` singleton streams the files via
/// `file_manager_->LoadFileStream` into one channel per file name (the part
/// of the path after the last '/'). Meanwhile, this thread reassembles the
/// channel contents with `AssembleIndexDatas`. The assembled buffers are then
/// wrapped, without copying, into a BinarySet and deserialized.
///
/// Asserts if "index_files" is absent. Panics if knowhere fails to
/// deserialize. Any exception thrown by the background loader is rethrown
/// here through `future.get()`.
void
VectorMemIndex::Load(const Config& config) {
    auto index_files =
        GetValueFromConfig<std::vector<std::string>>(config, "index_files");
    AssertInfo(index_files.has_value(),
               "index file paths is empty when load index");

    // One channel per distinct file name; duplicate names share a channel.
    std::map<std::string, storage::FieldDataChannelPtr> channels;
    for (const auto& file : index_files.value()) {
        auto key = file.substr(file.find_last_of('/') + 1);
        if (channels.find(key) == channels.end()) {
            channels.emplace(std::move(key),
                             std::make_shared<storage::FieldDataChannel>());
        }
    }

    auto& pool = ThreadPool::GetInstance();
    // The task captures `index_files` and `channels` by reference, so it must
    // finish before they go out of scope. The future.get() call below both
    // joins the task and propagates any exception from LoadFileStream, which
    // was previously dropped.
    auto future = pool.Submit(
        [&] { file_manager_->LoadFileStream(index_files.value(), channels); });

    std::unordered_map<std::string, storage::FieldDataPtr> result;
    AssembleIndexDatas(channels, result);
    future.get();

    BinarySet binary_set;
    for (auto& [key, data] : result) {
        auto size = data->Size();
        // Zero-copy view over the assembled FieldData. The deleter frees
        // nothing itself. It holds a reference to `data` so the backing
        // buffer stays valid for as long as any copy of `buf` exists.
        // This uses an init-capture because capturing a structured binding
        // directly is ill-formed before C++20.
        auto deleter = [keep_alive = data](uint8_t*) {};
        auto buf = std::shared_ptr<uint8_t[]>(
            static_cast<uint8_t*>(const_cast<void*>(data->Data())), deleter);
        binary_set.Append(key, buf, size);
    }
    LoadWithoutAssemble(binary_set, config);
}
void
VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset,
const Config& config) {
knowhere::Json index_config;
index_config.update(config);
SetDim(dataset->GetDim());
knowhere::TimeRecorder rc("BuildWithoutIds", 1);
auto stat = index_.Build(*dataset, index_config);
if (stat != knowhere::Status::success)
PanicCodeInfo(ErrorCodeEnum::BuildIndexError,
"failed to build index, " + KnowhereStatusString(stat));
rc.ElapseFromBegin("Done");
SetDim(index_.Dim());
}
/// Builds the index from the raw vectors stored in the binlog files listed
/// under "insert_files" in `config`.
///
/// All chunks are cached in memory and checked for a consistent dimension.
/// They are then concatenated into one contiguous buffer, which is handed to
/// `BuildWithDataset`. "insert_files" is removed from the config passed to
/// the build.
///
/// Asserts if "insert_files" is absent or if chunks disagree on dimension.
void
VectorMemIndex::Build(const Config& config) {
    auto insert_files =
        GetValueFromConfig<std::vector<std::string>>(config, "insert_files");
    AssertInfo(insert_files.has_value(),
               "insert file paths is empty when build index");
    auto field_datas =
        file_manager_->CacheRawDataToMemory(insert_files.value());

    int64_t total_size = 0;
    int64_t total_num_rows = 0;
    int64_t dim = 0;
    for (const auto& data : field_datas) {
        total_size += data->Size();
        total_num_rows += data->get_num_rows();
        AssertInfo(dim == 0 || dim == data->get_dim(),
                   "inconsistent dim value between field datas!");
        dim = data->get_dim();
    }

    auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[total_size]);
    int64_t offset = 0;
    // Iterate by reference so that reset() drops the vector's own reference
    // and frees each chunk immediately after it is copied. With a by-value
    // loop variable, reset() only released a temporary copy, and every chunk
    // stayed alive until clear().
    for (auto& data : field_datas) {
        std::memcpy(buf.get() + offset, data->Data(), data->Size());
        offset += data->Size();
        data.reset();
    }
    field_datas.clear();

    Config build_config;
    build_config.update(config);
    build_config.erase("insert_files");

    auto dataset = GenDataset(total_num_rows, dim, buf.get());
    BuildWithDataset(dataset, build_config);
}
void
VectorMemIndex::AddWithDataset(const DatasetPtr& dataset,
const Config& config) {
knowhere::Json index_config;
index_config.update(config);
knowhere::TimeRecorder rc("AddWithDataset", 1);
auto stat = index_.Add(*dataset, index_config);
if (stat != knowhere::Status::success)
PanicCodeInfo(ErrorCodeEnum::BuildIndexError,
"failed to append index, " + KnowhereStatusString(stat));
rc.ElapseFromBegin("Done");
}
/// Searches the index and converts the knowhere result into a SearchResult.
///
/// If the search params contain RADIUS, a range search is run (RANGE_FILTER
/// is validated against RADIUS when present) and its output is reshaped by
/// `ReGenRangeSearchResult`. Otherwise a plain top-k search is run.
///
/// @param dataset     query vectors; `GetRows()` gives the number of queries.
/// @param search_info topk, extra search params, and `round_decimal_`.
/// @param bitset      filter bitset forwarded to knowhere.
/// @return `num_queries * topk` ids and distances. When
///         `round_decimal_ != -1`, each distance is rounded to that many
///         decimal places.
std::unique_ptr<SearchResult>
VectorMemIndex::Query(const DatasetPtr dataset,
                      const SearchInfo& search_info,
                      const BitsetView& bitset) {
    // AssertInfo(GetMetricType() == search_info.metric_type_,
    //            "Metric type of field index isn't the same with search info");
    auto num_queries = dataset->GetRows();
    knowhere::Json search_conf = search_info.search_params_;
    auto topk = search_info.topk_;
    // TODO :: check dim of search data
    auto final = [&] {
        search_conf[knowhere::meta::TOPK] = topk;
        search_conf[knowhere::meta::METRIC_TYPE] = GetMetricType();
        if (CheckKeyInConfig(search_conf, RADIUS)) {
            if (CheckKeyInConfig(search_conf, RANGE_FILTER)) {
                CheckRangeSearchParam(search_conf[RADIUS],
                                      search_conf[RANGE_FILTER],
                                      GetMetricType());
            }
            auto res = index_.RangeSearch(*dataset, search_conf, bitset);
            if (!res.has_value()) {
                PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
                              fmt::format("failed to range search: {}: {}",
                                          KnowhereStatusString(res.error()),
                                          res.what()));
            }
            return ReGenRangeSearchResult(
                res.value(), topk, num_queries, GetMetricType());
        } else {
            auto res = index_.Search(*dataset, search_conf, bitset);
            if (!res.has_value()) {
                PanicCodeInfo(ErrorCodeEnum::UnexpectedError,
                              fmt::format("failed to search: {}: {}",
                                          KnowhereStatusString(res.error()),
                                          res.what()));
            }
            return res.value();
        }
    }();

    auto ids = final->GetIds();
    float* distances = const_cast<float*>(final->GetDistance());
    final->SetIsOwner(true);

    auto round_decimal = search_info.round_decimal_;
    auto total_num = num_queries * topk;
    if (round_decimal != -1) {
        const float multiplier = std::pow(10.0, round_decimal);
        // 64-bit index: `total_num` is nq * topk and may exceed INT_MAX.
        for (int64_t i = 0; i < total_num; ++i) {
            distances[i] = std::round(distances[i] * multiplier) / multiplier;
        }
    }

    auto result = std::make_unique<SearchResult>();
    result->seg_offsets_.resize(total_num);
    result->distances_.resize(total_num);
    result->total_nq_ = num_queries;
    result->unity_topK_ = topk;
    std::copy_n(ids, total_num, result->seg_offsets_.data());
    std::copy_n(distances, total_num, result->distances_.data());
    return result;
}
/// Returns whether the knowhere index retains raw vector data for this
/// index's metric type, as answered by `index_.HasRawData`.
const bool
VectorMemIndex::HasRawData() const {
    const auto& metric = GetMetricType();
    return index_.HasRawData(metric);
}
/// Fetches the stored vectors for the ids in `dataset` and returns them as a
/// flat byte buffer.
///
/// For index types in the binary list, each row occupies `dim / 8` bytes.
/// Otherwise each row occupies `dim` floats. Panics with
/// `ErrorCodeEnum::UnexpectedError` if knowhere cannot return the vectors.
std::vector<uint8_t>
VectorMemIndex::GetVector(const DatasetPtr dataset) const {
    auto res = index_.GetVectorByIds(*dataset);
    if (!res.has_value()) {
        PanicCodeInfo(
            ErrorCodeEnum::UnexpectedError,
            "failed to get vector, " + KnowhereStatusString(res.error()));
    }

    const auto& fetched = res.value();
    auto row_num = fetched->GetRows();
    auto dim = fetched->GetDim();
    const int64_t data_size =
        is_in_bin_list(GetIndexType())
            ? static_cast<int64_t>(dim / 8 * row_num)
            : static_cast<int64_t>(dim * row_num * sizeof(float));

    std::vector<uint8_t> raw_data(data_size);
    memcpy(raw_data.data(), fetched->GetTensor(), data_size);
    return raw_data;
}
} // namespace milvus::index