milvus/internal/core/src/indexbuilder/VecIndexCreator.cpp
Cai Yudong f4ebd3a9ce
Upgrade to knowhere v1.1.0 (#16186)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
2022-03-25 13:49:25 +08:00

360 lines
14 KiB
C++

// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <exception>
#include <map>
#include <google/protobuf/text_format.h>
#include "exceptions/EasyAssert.h"
#include "pb/index_cgo_msg.pb.h"
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/utils.h"
#include "knowhere/common/Timer.h"
#include "knowhere/common/Utils.h"
#include "knowhere/index/vector_index/ConfAdapterMgr.h"
#include "knowhere/index/vector_index/VecIndexFactory.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
namespace milvus::indexbuilder {
// Constructs an index creator from two serialized parameter strings.
//
// Both strings are stored, parsed into the configuration via parse(), and the
// resulting index type / index mode are handed to the knowhere factory to
// create the underlying index object.
//
// serialized_type_params  : serialized type parameters (C string).
// serialized_index_params : serialized index parameters (C string).
//
// Asserts when is_unsupported() rejects the index type / metric type pair, or
// when the factory returns a null index.
VecIndexCreator::VecIndexCreator(const char* serialized_type_params, const char* serialized_index_params) {
    type_params_ = std::string(serialized_type_params);
    index_params_ = std::string(serialized_index_params);

    // Populate config_ before any of the get_*() accessors are consulted.
    parse();

    auto index_type = get_index_type();
    auto metric_type = get_metric_type();
    auto index_mode = get_index_mode();
    AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type);

    auto& factory = knowhere::VecIndexFactory::GetInstance();
    index_ = factory.CreateVecIndex(index_type, index_mode);
    AssertInfo(index_ != nullptr, "[VecIndexCreator]Index is null after create index");
}
// Deserializes a text-format protobuf parameter message into `conf`, then
// converts the known numeric parameters from their string form to numbers.
//
// ParamsT               : protobuf message type exposing params_size() / params(i),
//                         where each entry has key() and value().
// serialized_params_str : text-format serialized message.
// conf                  : configuration that receives every key/value pair.
//
// Asserts when the text cannot be parsed. Conversions use std::stoi / std::stof,
// so a non-numeric value stored under a numeric key makes them throw.
//
// ugly here, ParamsT will just be MapParams later
template <typename ParamsT>
void
VecIndexCreator::parse_impl(const std::string& serialized_params_str, knowhere::Config& conf) {
    ParamsT params;
    bool deserialized_success = google::protobuf::TextFormat::ParseFromString(serialized_params_str, &params);
    AssertInfo(deserialized_success, "[VecIndexCreator]Deserialize params failed");

    // Step 1: copy every entry into the config verbatim (values are still strings).
    for (int idx = 0, count = params.params_size(); idx < count; ++idx) {
        const auto& entry = params.params(idx);
        conf[entry.key()] = entry.value();
    }

    // Step 2: rewrite the known numeric keys in place with their parsed values.
    auto to_int = [](const std::string& text) { return std::stoi(text); };
    auto to_float = [](const std::string& text) { return std::stof(text); };
    // Keys handled by these helpers have no default: an absent key stays absent.
    auto int_param = [&](const std::string& key) { check_parameter<int>(conf, key, to_int, std::nullopt); };
    auto float_param = [&](const std::string& key) { check_parameter<float>(conf, key, to_float, std::nullopt); };

    /***************************** meta *******************************/
    int_param(knowhere::meta::DIM);
    int_param(knowhere::meta::TOPK);

    /***************************** IVF Params *******************************/
    int_param(knowhere::IndexParams::nprobe);
    int_param(knowhere::IndexParams::nlist);
    int_param(knowhere::IndexParams::m);
    int_param(knowhere::IndexParams::nbits);

    /************************** NSG Parameter **************************/
    int_param(knowhere::IndexParams::knng);
    int_param(knowhere::IndexParams::search_length);
    int_param(knowhere::IndexParams::out_degree);
    int_param(knowhere::IndexParams::candidate);

    /************************** HNSW Params *****************************/
    int_param(knowhere::IndexParams::efConstruction);
    int_param(knowhere::IndexParams::M);
    int_param(knowhere::IndexParams::ef);

    /************************** Annoy Params *****************************/
    int_param(knowhere::IndexParams::n_trees);
    int_param(knowhere::IndexParams::search_k);

    /************************** PQ Params *****************************/
    int_param(knowhere::IndexParams::PQM);

    /************************** NGT Params *****************************/
    int_param(knowhere::IndexParams::edge_size);

    /************************** NGT Search Params *****************************/
    float_param(knowhere::IndexParams::epsilon);
    int_param(knowhere::IndexParams::max_search_edges);

    /************************** NGT_PANNG Params *****************************/
    int_param(knowhere::IndexParams::forcedly_pruned_edge_size);
    int_param(knowhere::IndexParams::selectively_pruned_edge_size);

    /************************** NGT_ONNG Params *****************************/
    int_param(knowhere::IndexParams::outgoing_edge_size);
    int_param(knowhere::IndexParams::incoming_edge_size);

    /************************** Serialize Params *******************************/
    // The only key with a default: the slice size becomes 4 when not supplied.
    check_parameter<int>(conf, knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, to_int, std::optional{4});
}
// Parses the stored type and index parameter strings and merges both results
// into config_.
//
// type_params_ is parsed into type_config_ and index_params_ into
// index_config_; config_ is then updated with the type config first and the
// index config second.
void
VecIndexCreator::parse() {
    parse_impl<milvus::proto::indexcgo::TypeParams>(type_params_, type_config_);
    parse_impl<milvus::proto::indexcgo::IndexParams>(index_params_, index_config_);

    // update() is described by the original author as behaving like Python's
    // dict().update; presumably a key present in both configs therefore ends up
    // with the index-parameter value — confirm against the Config type.
    config_.update(type_config_);
    config_.update(index_config_);
}
// Normalizes a single configuration entry.
//
// conf      : configuration to inspect and modify in place.
// key       : name of the entry.
// fn        : converter applied to the current value when the key exists; its
//             result replaces the stored value.
// default_v : value stored under `key` when the key is absent; when it is
//             empty an absent key is left absent.
template <typename T>
void
VecIndexCreator::check_parameter(knowhere::Config& conf,
                                 const std::string& key,
                                 std::function<T(std::string)> fn,
                                 std::optional<T> default_v) {
    if (conf.contains(key)) {
        // Take a copy of the current value first, then overwrite the entry with
        // the converted result.
        auto current = conf[key];
        conf[key] = fn(current);
        return;
    }

    if (default_v.has_value()) {
        conf[key] = default_v.value();
    }
}
template <typename T>
std::optional<T>
VecIndexCreator::get_config_by_name(std::string name) {
if (config_.contains(name)) {
return {config_[name].get<T>()};
}
return std::nullopt;
}
// Returns the vector dimension stored in config_ under knowhere::meta::DIM.
// Asserts when the configuration holds no dimension.
int64_t
VecIndexCreator::dim() {
    const std::optional<int64_t> dimension = get_config_by_name<int64_t>(knowhere::meta::DIM);
    AssertInfo(dimension.has_value(), "[VecIndexCreator]Dimension doesn't have value");
    const int64_t result = dimension.value();
    return result;
}
// Builds the index from `dataset` without caller-supplied row ids.
//
// Steps:
//   1. record the dataset's row count in config_;
//   2. for IVF_PQ, default `nbits` to 8 when it is not configured;
//   3. print the effective configuration to stdout and validate it with the
//      index type's conf adapter (asserts on failure);
//   4. panic for index types that is_in_need_id_list() reports;
//   5. build the index with BuildAll, and for index types in the "nm" list
//      also keep a copy of the raw vectors via StoreRawData().
//
// dataset : source dataset; its ROWS entry is read here.
void
VecIndexCreator::BuildWithoutIds(const knowhere::DatasetPtr& dataset) {
    auto index_type = get_index_type();
    auto index_mode = get_index_mode();
    config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS);

    if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ && !config_.contains(knowhere::IndexParams::nbits)) {
        config_[knowhere::IndexParams::nbits] = 8;
    }

    auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type);
    // Diagnostic dump of the effective build configuration ("Knowhere" was
    // previously misspelled in this message).
    std::cout << "Knowhere BuildWithoutIds config_ is " << config_ << std::endl;
    AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!");

    if (is_in_need_id_list(index_type)) {
        PanicInfo(std::string(index_type) + " doesn't support build without ids yet!");
    }

    knowhere::TimeRecorder rc("BuildWithoutIds", 1);
    // Previously split by index type into BuildAll vs. Train + AddWithoutIds;
    // every index type now goes through BuildAll.
    index_->BuildAll(dataset, config_);
    rc.RecordSection("TrainAndAdd");

    if (is_in_nm_list(index_type)) {
        StoreRawData(dataset);
        rc.RecordSection("StoreRawData");
    }
    rc.ElapseFromBegin("Done");
}
// Builds the index from a dataset that carries an IDS field.
//
// Asserts when the dataset has no IDS entry or when the conf adapter rejects
// the configuration. The dataset's row count is written into config_, and for
// IVF_PQ `nbits` defaults to 8 when it is not configured. Index types in the
// "nm" list additionally get a copy of the raw vectors via StoreRawData().
//
// dataset : source dataset; its ROWS entry is read here.
void
VecIndexCreator::BuildWithIds(const knowhere::DatasetPtr& dataset) {
    AssertInfo(dataset->data().find(knowhere::meta::IDS) != dataset->data().end(),
               "[VecIndexCreator]Can't find ids field in dataset");

    auto index_type = get_index_type();
    auto index_mode = get_index_mode();
    config_[knowhere::meta::ROWS] = dataset->Get<int64_t>(knowhere::meta::ROWS);

    const bool is_ivf_pq = (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ);
    if (is_ivf_pq && !config_.contains(knowhere::IndexParams::nbits)) {
        config_[knowhere::IndexParams::nbits] = 8;
    }

    auto conf_adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type);
    AssertInfo(conf_adapter->CheckTrain(config_, index_mode), "something wrong in index parameters!");

    // A single BuildAll call replaces the former separate Train + Add calls.
    index_->BuildAll(dataset, config_);

    if (is_in_nm_list(index_type)) {
        StoreRawData(dataset);
    }
}
// Copies the dataset's raw vectors into raw_data_ for index types in the "nm"
// list; does nothing for any other index type.
//
// The byte count is dim / 8 * rows for index types in the binary list, and
// dim * rows * sizeof(float) otherwise.
//
// dataset : source dataset; its TENSOR, ROWS and DIM entries are read.
void
VecIndexCreator::StoreRawData(const knowhere::DatasetPtr& dataset) {
    auto index_type = get_index_type();
    if (!is_in_nm_list(index_type)) {
        return;
    }

    auto tensor = dataset->Get<const void*>(knowhere::meta::TENSOR);
    auto row_num = dataset->Get<int64_t>(knowhere::meta::ROWS);
    auto dim = dataset->Get<int64_t>(knowhere::meta::DIM);

    const int64_t data_size =
        is_in_bin_list(index_type) ? dim / 8 * row_num : static_cast<int64_t>(dim * row_num * sizeof(float));

    raw_data_.resize(data_size);
    memcpy(raw_data_.data(), tensor, data_size);
}
// Serializes the index into a BinarySet.
//
// For index types in the "nm" list, a fresh copy of raw_data_ is appended
// under RAW_DATA and the set is passed to knowhere::Disassemble with the
// configured slice size (in megabytes) converted to bytes.
knowhere::BinarySet
VecIndexCreator::Serialize() {
    auto ret = index_->Serialize(config_);
    if (!is_in_nm_list(get_index_type())) {
        return ret;
    }

    // The binary set receives its own buffer rather than a view onto raw_data_.
    const auto raw_size = raw_data_.size();
    std::shared_ptr<uint8_t[]> raw_copy(new uint8_t[raw_size], std::default_delete<uint8_t[]>());
    memcpy(raw_copy.get(), raw_data_.data(), raw_size);
    ret.Append(RAW_DATA, raw_copy, raw_size);

    // https://github.com/milvus-io/milvus/issues/6421
    // Disassemble will only divide the raw vectors, other keys were already divided
    const auto slice_size_mb = get_index_file_slice_size();
    knowhere::Disassemble(slice_size_mb * 1024 * 1024, ret);
    return ret;
}
// Loads the index from `binary_set`.
//
// When the set contains a RAW_DATA entry, its bytes are first copied into
// raw_data_ (replacing the previous contents); the whole set is then handed
// to the underlying index.
void
VecIndexCreator::Load(const knowhere::BinarySet& binary_set) {
    for (const auto& [entry_name, binary] : binary_set.binary_map_) {
        if (entry_name == RAW_DATA) {
            const auto data_size = binary->size;
            raw_data_.clear();
            raw_data_.resize(data_size);
            memcpy(raw_data_.data(), binary->data.get(), data_size);
            break;  // only the first RAW_DATA entry is used
        }
    }
    index_->Load(binary_set);
}
// Returns the index type configured under "index_type", or
// knowhere::IndexEnum::INDEX_FAISS_IVFPQ when the key is absent.
//
// The type is read from the config rather than from index_->index_type()
// because of a knowhere bug: the index_type of all ivf-based index will change
// to ivf flat after loaded.
std::string
VecIndexCreator::get_index_type() {
    auto configured = get_config_by_name<std::string>("index_type");
    if (configured.has_value()) {
        return configured.value();
    }
    return knowhere::IndexEnum::INDEX_FAISS_IVFPQ;
}
std::string
VecIndexCreator::get_metric_type() {
auto type = get_config_by_name<std::string>(knowhere::Metric::TYPE);
if (type.has_value()) {
return type.value();
} else {
auto index_type = get_index_type();
if (is_in_bin_list(index_type)) {
return knowhere::Metric::JACCARD;
} else {
return knowhere::Metric::L2;
}
}
}
// Returns the index mode selected by the "index_mode" config entry
// ("CPU" or "GPU"); MODE_CPU when the entry is absent.
//
// The lookup table is const and queried with find(). The previous
// mode_map[...] lookup used operator[] on a shared function-local static map,
// which inserted a new entry for every unrecognised mode string (unbounded
// growth, and unsynchronised mutation of shared state).
knowhere::IndexMode
VecIndexCreator::get_index_mode() {
    static const std::map<std::string, knowhere::IndexMode> mode_map = {
        {"CPU", knowhere::IndexMode::MODE_CPU},
        {"GPU", knowhere::IndexMode::MODE_GPU},
    };

    auto mode = get_config_by_name<std::string>("index_mode");
    if (!mode.has_value()) {
        return knowhere::IndexMode::MODE_CPU;
    }

    const auto found = mode_map.find(mode.value());
    if (found != mode_map.end()) {
        return found->second;
    }
    // NOTE(review): the original returned a value-initialized IndexMode for an
    // unrecognised string; this assumes that value is MODE_CPU — confirm
    // against the enum definition.
    return knowhere::IndexMode::MODE_CPU;
}
int64_t
VecIndexCreator::get_index_file_slice_size() {
if (config_.contains(knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE)) {
return config_[knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE].get<int64_t>();
}
return 4; // by default
}
// Runs a query over `dataset` using the creator's own configuration (config_)
// as the search parameters and returns the collected result.
std::unique_ptr<VecIndexCreator::QueryResult>
VecIndexCreator::Query(const knowhere::DatasetPtr& dataset) {
    // Return the prvalue directly: wrapping it in std::move is redundant and
    // prevents copy elision.
    return QueryImpl(dataset, config_);
}
// Runs a query over `dataset` using search parameters supplied by the caller.
//
// dataset                  : query dataset.
// serialized_search_params : text-format serialized MapParams message; it is
//                            parsed with parse_impl, which asserts when the
//                            text cannot be deserialized.
std::unique_ptr<VecIndexCreator::QueryResult>
VecIndexCreator::QueryWithParam(const knowhere::DatasetPtr& dataset, const char* serialized_search_params) {
    namespace indexcgo = milvus::proto::indexcgo;
    knowhere::Config search_conf;
    parse_impl<indexcgo::MapParams>(std::string(serialized_search_params), search_conf);
    // Return the prvalue directly: wrapping it in std::move is redundant and
    // prevents copy elision.
    return QueryImpl(dataset, search_conf);
}
// Executes a query against the underlying index and copies the result into a
// QueryResult.
//
// dataset : query dataset; its ROWS entry is used as the number of queries (nq).
// conf    : search configuration passed to the index.
//
// Returns a QueryResult with nq, topk and nq * topk ids / distances copied out
// of the index's result dataset.
//
// For index types in the "nm" list, LoadRawData() is run once (guarded by
// raw_data_loaded_) before the first query.
std::unique_ptr<VecIndexCreator::QueryResult>
VecIndexCreator::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Config& conf) {
    auto load_raw_data_closure = [&]() { LoadRawData(); };  // hide this pointer
    auto index_type = get_index_type();
    if (is_in_nm_list(index_type)) {
        std::call_once(raw_data_loaded_, load_raw_data_closure);
    }

    auto res = index_->Query(dataset, conf, nullptr);
    auto ids = res->Get<int64_t*>(knowhere::meta::IDS);
    auto distances = res->Get<float*>(knowhere::meta::DISTANCE);
    auto nq = dataset->Get<int64_t>(knowhere::meta::ROWS);

    // The search ran with `conf`, so the top-k used to size and copy the
    // result must come from `conf` as well. Reading it from config_
    // unconditionally mis-sizes the copies below whenever QueryWithParam
    // supplies a different top-k. config_ remains the fallback when `conf`
    // does not carry the key.
    int64_t k = 0;
    if (conf.contains(knowhere::meta::TOPK)) {
        k = conf[knowhere::meta::TOPK].get<int64_t>();
    } else {
        k = config_[knowhere::meta::TOPK].get<int64_t>();
    }

    auto query_res = std::make_unique<VecIndexCreator::QueryResult>();
    query_res->nq = nq;
    query_res->topk = k;
    query_res->ids.resize(nq * k);
    query_res->distances.resize(nq * k);
    memcpy(query_res->ids.data(), ids, sizeof(int64_t) * nq * k);
    memcpy(query_res->distances.data(), distances, sizeof(float) * nq * k);
    // Returning the local by name lets the compiler elide or move it.
    return query_res;
}
void
VecIndexCreator::LoadRawData() {
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
auto bs = index_->Serialize(config_);
auto bptr = std::make_shared<knowhere::Binary>();
auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction
bptr->data = std::shared_ptr<uint8_t[]>(static_cast<uint8_t*>(raw_data_.data()), deleter);
bptr->size = raw_data_.size();
bs.Append(RAW_DATA, bptr);
index_->Load(bs);
}
}
} // namespace milvus::indexbuilder