mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-04 18:02:08 +08:00
Merge branch 'be_stable' into 'branch-0.3.1-xiaojun'
MS-259 Be stable See merge request megasearch/milvus!254 Former-commit-id: af719d39a78af40fbeb44490eb4c3233da107cd4
This commit is contained in:
commit
2d329aad68
@ -3,6 +3,8 @@
|
||||
* Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include <stdexcept>
|
||||
|
||||
#include <src/server/ServerConfig.h>
|
||||
#include <src/metrics/Metrics.h>
|
||||
#include "Log.h"
|
||||
@ -11,6 +13,8 @@
|
||||
#include "ExecutionEngineImpl.h"
|
||||
#include "wrapper/knowhere/vec_index.h"
|
||||
#include "wrapper/knowhere/vec_impl.h"
|
||||
#include "knowhere/common/exception.h"
|
||||
#include "Exception.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
@ -21,9 +25,13 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension,
|
||||
const std::string &location,
|
||||
EngineType type)
|
||||
: location_(location), dim(dimension), build_type(type) {
|
||||
index_ = CreatetVecIndex(EngineType::FAISS_IDMAP);
|
||||
current_type = EngineType::FAISS_IDMAP;
|
||||
std::static_pointer_cast<BFIndex>(index_)->Build(dimension);
|
||||
|
||||
index_ = CreatetVecIndex(EngineType::FAISS_IDMAP);
|
||||
if (!index_) throw Exception("Create Empty VecIndex");
|
||||
|
||||
auto ec = std::static_pointer_cast<BFIndex>(index_)->Build(dimension);
|
||||
if (ec != server::KNOWHERE_SUCCESS) { throw Exception("Build index error"); }
|
||||
}
|
||||
|
||||
ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index,
|
||||
@ -61,7 +69,10 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
|
||||
}
|
||||
|
||||
Status ExecutionEngineImpl::AddWithIds(long n, const float *xdata, const long *xids) {
|
||||
index_->Add(n, xdata, xids, Config::object{{"dim", dim}});
|
||||
auto ec = index_->Add(n, xdata, xids, Config::object{{"dim", dim}});
|
||||
if (ec != server::KNOWHERE_SUCCESS) {
|
||||
return Status::Error("Add error");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -82,7 +93,10 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
|
||||
}
|
||||
|
||||
Status ExecutionEngineImpl::Serialize() {
|
||||
write_index(index_, location_);
|
||||
auto ec = write_index(index_, location_);
|
||||
if (ec != server::KNOWHERE_SUCCESS) {
|
||||
return Status::Error("Serialize: write to disk error");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -91,9 +105,16 @@ Status ExecutionEngineImpl::Load() {
|
||||
bool to_cache = false;
|
||||
auto start_time = METRICS_NOW_TIME;
|
||||
if (!index_) {
|
||||
index_ = read_index(location_);
|
||||
to_cache = true;
|
||||
ENGINE_LOG_DEBUG << "Disk io from: " << location_;
|
||||
try {
|
||||
index_ = read_index(location_);
|
||||
to_cache = true;
|
||||
ENGINE_LOG_DEBUG << "Disk io from: " << location_;
|
||||
} catch (knowhere::KnowhereException &e) {
|
||||
ENGINE_LOG_ERROR << e.what();
|
||||
return Status::Error(e.what());
|
||||
} catch (std::exception &e) {
|
||||
return Status::Error(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
if (to_cache) {
|
||||
@ -118,11 +139,22 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
|
||||
|
||||
auto to_merge = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location);
|
||||
if (!to_merge) {
|
||||
to_merge = read_index(location);
|
||||
try {
|
||||
to_merge = read_index(location);
|
||||
} catch (knowhere::KnowhereException &e) {
|
||||
ENGINE_LOG_ERROR << e.what();
|
||||
return Status::Error(e.what());
|
||||
} catch (std::exception &e) {
|
||||
return Status::Error(e.what());
|
||||
}
|
||||
}
|
||||
|
||||
if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) {
|
||||
index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
|
||||
auto ec = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
|
||||
if (ec != server::KNOWHERE_SUCCESS) {
|
||||
ENGINE_LOG_ERROR << "Merge: Add Error";
|
||||
return Status::Error("Merge: Add Error");
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status::Error("file index type is not idmap");
|
||||
@ -134,13 +166,16 @@ ExecutionEngineImpl::BuildIndex(const std::string &location) {
|
||||
ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_;
|
||||
|
||||
auto from_index = std::dynamic_pointer_cast<BFIndex>(index_);
|
||||
ENGINE_LOG_DEBUG << "BuildIndex EngineTypee: " << int(build_type);
|
||||
auto to_index = CreatetVecIndex(build_type);
|
||||
ENGINE_LOG_DEBUG << "Build Params: [gpu_id] " << gpu_num;
|
||||
to_index->BuildAll(Count(),
|
||||
from_index->GetRawVectors(),
|
||||
from_index->GetRawIds(),
|
||||
Config::object{{"dim", Dimension()}, {"gpu_id", gpu_num}});
|
||||
if (!to_index) {
|
||||
throw Exception("Create Empty VecIndex");
|
||||
}
|
||||
|
||||
auto ec = to_index->BuildAll(Count(),
|
||||
from_index->GetRawVectors(),
|
||||
from_index->GetRawIds(),
|
||||
Config::object{{"dim", Dimension()}, {"gpu_id", gpu_num}});
|
||||
if (ec != server::KNOWHERE_SUCCESS) { throw Exception("Build index error"); }
|
||||
|
||||
return std::make_shared<ExecutionEngineImpl>(to_index, location, build_type);
|
||||
}
|
||||
@ -151,7 +186,11 @@ Status ExecutionEngineImpl::Search(long n,
|
||||
float *distances,
|
||||
long *labels) const {
|
||||
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe_;
|
||||
index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe_}});
|
||||
auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe_}});
|
||||
if (ec != server::KNOWHERE_SUCCESS) {
|
||||
ENGINE_LOG_ERROR << "Search error";
|
||||
return Status::Error("Search: Search Error");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@ -54,6 +54,12 @@ constexpr ServerError SERVER_LICENSE_VALIDATION_FAIL = ToGlobalServerErrorCode(5
|
||||
|
||||
constexpr ServerError DB_META_TRANSACTION_FAILED = ToGlobalServerErrorCode(1000);
|
||||
|
||||
using KnowhereError = int32_t;
|
||||
constexpr KnowhereError KNOWHERE_SUCCESS = 0;
|
||||
constexpr KnowhereError KNOWHERE_ERROR = ToGlobalServerErrorCode(1);
|
||||
constexpr KnowhereError KNOWHERE_INVALID_ARGUMENT = ToGlobalServerErrorCode(2);
|
||||
constexpr KnowhereError KNOWHERE_UNEXPECTED_ERROR = ToGlobalServerErrorCode(3);
|
||||
|
||||
class ServerException : public std::exception {
|
||||
public:
|
||||
ServerException(ServerError error_code,
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <src/utils/Log.h>
|
||||
#include "knowhere/index/vector_index/idmap.h"
|
||||
#include "knowhere/index/vector_index/gpu_ivf.h"
|
||||
#include "knowhere/common/exception.h"
|
||||
|
||||
#include "vec_impl.h"
|
||||
#include "data_transfer.h"
|
||||
@ -19,77 +20,110 @@ namespace engine {
|
||||
|
||||
using namespace zilliz::knowhere;
|
||||
|
||||
void VecIndexImpl::BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) {
|
||||
dim = cfg["dim"].as<int>();
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
server::KnowhereError VecIndexImpl::BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) {
|
||||
try {
|
||||
dim = cfg["dim"].as<int>();
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
|
||||
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
|
||||
index_->set_preprocessor(preprocessor);
|
||||
auto nlist = int(nb / 1000000.0 * 16384);
|
||||
auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
|
||||
auto model = index_->Train(dataset, cfg_t);
|
||||
index_->set_index_model(model);
|
||||
index_->Add(dataset, cfg);
|
||||
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
|
||||
index_->set_preprocessor(preprocessor);
|
||||
auto nlist = int(nb / 1000000.0 * 16384);
|
||||
auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
|
||||
auto model = index_->Train(dataset, cfg_t);
|
||||
index_->set_index_model(model);
|
||||
index_->Add(dataset, cfg);
|
||||
} catch (KnowhereException &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_UNEXPECTED_ERROR;
|
||||
} catch (jsoncons::json_exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_INVALID_ARGUMENT;
|
||||
} catch (std::exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_ERROR;
|
||||
}
|
||||
return server::KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
void VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
|
||||
// TODO(linxj): Assert index is trained;
|
||||
server::KnowhereError VecIndexImpl::Add(const long &nb, const float *xb, const long *ids, const Config &cfg) {
|
||||
try {
|
||||
auto d = cfg.get_with_default("dim", dim);
|
||||
auto dataset = GenDatasetWithIds(nb, d, xb, ids);
|
||||
|
||||
auto d = cfg.get_with_default("dim", dim);
|
||||
auto dataset = GenDatasetWithIds(nb, d, xb, ids);
|
||||
|
||||
index_->Add(dataset, cfg);
|
||||
index_->Add(dataset, cfg);
|
||||
} catch (KnowhereException &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_UNEXPECTED_ERROR;
|
||||
} catch (jsoncons::json_exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_INVALID_ARGUMENT;
|
||||
} catch (std::exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_ERROR;
|
||||
}
|
||||
return server::KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
void VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
|
||||
// TODO: Assert index is trained;
|
||||
server::KnowhereError VecIndexImpl::Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) {
|
||||
try {
|
||||
auto k = cfg["k"].as<int>();
|
||||
auto d = cfg.get_with_default("dim", dim);
|
||||
auto dataset = GenDataset(nq, d, xq);
|
||||
|
||||
auto k = cfg["k"].as<int>();
|
||||
auto d = cfg.get_with_default("dim", dim);
|
||||
auto dataset = GenDataset(nq, d, xq);
|
||||
Config search_cfg;
|
||||
auto res = index_->Search(dataset, cfg);
|
||||
auto ids_array = res->array()[0];
|
||||
auto dis_array = res->array()[1];
|
||||
|
||||
Config search_cfg;
|
||||
auto res = index_->Search(dataset, cfg);
|
||||
auto ids_array = res->array()[0];
|
||||
auto dis_array = res->array()[1];
|
||||
//{
|
||||
// auto& ids = ids_array;
|
||||
// auto& dists = dis_array;
|
||||
// std::stringstream ss_id;
|
||||
// std::stringstream ss_dist;
|
||||
// for (auto i = 0; i < 10; i++) {
|
||||
// for (auto j = 0; j < k; ++j) {
|
||||
// ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
|
||||
// ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
|
||||
// }
|
||||
// ss_id << std::endl;
|
||||
// ss_dist << std::endl;
|
||||
// }
|
||||
// std::cout << "id\n" << ss_id.str() << std::endl;
|
||||
// std::cout << "dist\n" << ss_dist.str() << std::endl;
|
||||
//}
|
||||
|
||||
//{
|
||||
// auto& ids = ids_array;
|
||||
// auto& dists = dis_array;
|
||||
// std::stringstream ss_id;
|
||||
// std::stringstream ss_dist;
|
||||
// for (auto i = 0; i < 10; i++) {
|
||||
// for (auto j = 0; j < k; ++j) {
|
||||
// ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
|
||||
// ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
|
||||
// }
|
||||
// ss_id << std::endl;
|
||||
// ss_dist << std::endl;
|
||||
// }
|
||||
// std::cout << "id\n" << ss_id.str() << std::endl;
|
||||
// std::cout << "dist\n" << ss_dist.str() << std::endl;
|
||||
//}
|
||||
auto p_ids = ids_array->data()->GetValues<int64_t>(1, 0);
|
||||
auto p_dist = dis_array->data()->GetValues<float>(1, 0);
|
||||
|
||||
auto p_ids = ids_array->data()->GetValues<int64_t>(1, 0);
|
||||
auto p_dist = dis_array->data()->GetValues<float>(1, 0);
|
||||
|
||||
// TODO(linxj): avoid copy here.
|
||||
memcpy(ids, p_ids, sizeof(int64_t) * nq * k);
|
||||
memcpy(dist, p_dist, sizeof(float) * nq * k);
|
||||
// TODO(linxj): avoid copy here.
|
||||
memcpy(ids, p_ids, sizeof(int64_t) * nq * k);
|
||||
memcpy(dist, p_dist, sizeof(float) * nq * k);
|
||||
} catch (KnowhereException &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_UNEXPECTED_ERROR;
|
||||
} catch (jsoncons::json_exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_INVALID_ARGUMENT;
|
||||
} catch (std::exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_ERROR;
|
||||
}
|
||||
return server::KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
zilliz::knowhere::BinarySet VecIndexImpl::Serialize() {
|
||||
return index_->Serialize();
|
||||
}
|
||||
|
||||
void VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
|
||||
server::KnowhereError VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) {
|
||||
index_->Load(index_binary);
|
||||
dim = Dimension();
|
||||
return server::KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
int64_t VecIndexImpl::Dimension() {
|
||||
@ -114,56 +148,91 @@ int64_t *BFIndex::GetRawIds() {
|
||||
return std::static_pointer_cast<IDMAP>(index_)->GetRawIds();
|
||||
}
|
||||
|
||||
void BFIndex::Build(const int64_t &d) {
|
||||
dim = d;
|
||||
std::static_pointer_cast<IDMAP>(index_)->Train(dim);
|
||||
server::KnowhereError BFIndex::Build(const int64_t &d) {
|
||||
try {
|
||||
dim = d;
|
||||
std::static_pointer_cast<IDMAP>(index_)->Train(dim);
|
||||
} catch (KnowhereException &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_UNEXPECTED_ERROR;
|
||||
} catch (jsoncons::json_exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_INVALID_ARGUMENT;
|
||||
} catch (std::exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_ERROR;
|
||||
}
|
||||
return server::KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
void BFIndex::BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) {
|
||||
dim = cfg["dim"].as<int>();
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
server::KnowhereError BFIndex::BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) {
|
||||
try {
|
||||
dim = cfg["dim"].as<int>();
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
|
||||
std::static_pointer_cast<IDMAP>(index_)->Train(dim);
|
||||
index_->Add(dataset, cfg);
|
||||
std::static_pointer_cast<IDMAP>(index_)->Train(dim);
|
||||
index_->Add(dataset, cfg);
|
||||
} catch (KnowhereException &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_UNEXPECTED_ERROR;
|
||||
} catch (jsoncons::json_exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_INVALID_ARGUMENT;
|
||||
} catch (std::exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_ERROR;
|
||||
}
|
||||
return server::KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
// TODO(linxj): add lock here.
|
||||
void IVFMixIndex::BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) {
|
||||
WRAPPER_LOG_DEBUG << "Get Into Build IVFMIX";
|
||||
server::KnowhereError IVFMixIndex::BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) {
|
||||
try {
|
||||
dim = cfg["dim"].as<int>();
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
|
||||
dim = cfg["dim"].as<int>();
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
|
||||
index_->set_preprocessor(preprocessor);
|
||||
auto nlist = int(nb / 1000000.0 * 16384);
|
||||
auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
|
||||
auto model = index_->Train(dataset, cfg_t);
|
||||
index_->set_index_model(model);
|
||||
index_->Add(dataset, cfg);
|
||||
|
||||
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
|
||||
index_->set_preprocessor(preprocessor);
|
||||
auto nlist = int(nb / 1000000.0 * 16384);
|
||||
auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
|
||||
auto model = index_->Train(dataset, cfg_t);
|
||||
index_->set_index_model(model);
|
||||
index_->Add(dataset, cfg);
|
||||
|
||||
if (auto device_index = std::dynamic_pointer_cast<GPUIVF>(index_)) {
|
||||
auto host_index = device_index->Copy_index_gpu_to_cpu();
|
||||
index_ = host_index;
|
||||
} else {
|
||||
WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
|
||||
if (auto device_index = std::dynamic_pointer_cast<GPUIVF>(index_)) {
|
||||
auto host_index = device_index->Copy_index_gpu_to_cpu();
|
||||
index_ = host_index;
|
||||
} else {
|
||||
WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed";
|
||||
}
|
||||
} catch (KnowhereException &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_UNEXPECTED_ERROR;
|
||||
} catch (jsoncons::json_exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_INVALID_ARGUMENT;
|
||||
} catch (std::exception &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_ERROR;
|
||||
}
|
||||
return server::KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
void IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
|
||||
server::KnowhereError IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
|
||||
index_ = std::make_shared<IVF>();
|
||||
index_->Load(index_binary);
|
||||
dim = Dimension();
|
||||
return server::KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -19,19 +19,19 @@ class VecIndexImpl : public VecIndex {
|
||||
public:
|
||||
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
|
||||
: index_(std::move(index)), type(type) {};
|
||||
void BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) override;
|
||||
server::KnowhereError BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) override;
|
||||
IndexType GetType() override;
|
||||
int64_t Dimension() override;
|
||||
int64_t Count() override;
|
||||
void Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
|
||||
server::KnowhereError Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
|
||||
zilliz::knowhere::BinarySet Serialize() override;
|
||||
void Load(const zilliz::knowhere::BinarySet &index_binary) override;
|
||||
void Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override;
|
||||
server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) override;
|
||||
server::KnowhereError Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override;
|
||||
|
||||
protected:
|
||||
int64_t dim = 0;
|
||||
@ -43,27 +43,27 @@ class IVFMixIndex : public VecIndexImpl {
|
||||
public:
|
||||
explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
|
||||
IndexType::FAISS_IVFFLAT_MIX) {};
|
||||
void BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) override;
|
||||
void Load(const zilliz::knowhere::BinarySet &index_binary) override;
|
||||
server::KnowhereError BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) override;
|
||||
server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) override;
|
||||
};
|
||||
|
||||
class BFIndex : public VecIndexImpl {
|
||||
public:
|
||||
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
|
||||
IndexType::FAISS_IDMAP) {};
|
||||
void Build(const int64_t &d);
|
||||
server::KnowhereError Build(const int64_t &d);
|
||||
float *GetRawVectors();
|
||||
void BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) override;
|
||||
server::KnowhereError BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) override;
|
||||
int64_t *GetRawIds();
|
||||
};
|
||||
|
||||
|
||||
@ -7,9 +7,11 @@
|
||||
#include "knowhere/index/vector_index/idmap.h"
|
||||
#include "knowhere/index/vector_index/gpu_ivf.h"
|
||||
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
|
||||
#include "knowhere/common/exception.h"
|
||||
|
||||
#include "vec_index.h"
|
||||
#include "vec_impl.h"
|
||||
#include "wrapper_log.h"
|
||||
|
||||
|
||||
namespace zilliz {
|
||||
@ -153,23 +155,32 @@ VecIndexPtr read_index(const std::string &location) {
|
||||
return LoadVecIndex(current_type, load_data_list);
|
||||
}
|
||||
|
||||
void write_index(VecIndexPtr index, const std::string &location) {
|
||||
auto binaryset = index->Serialize();
|
||||
auto index_type = index->GetType();
|
||||
server::KnowhereError write_index(VecIndexPtr index, const std::string &location) {
|
||||
try {
|
||||
auto binaryset = index->Serialize();
|
||||
auto index_type = index->GetType();
|
||||
|
||||
FileIOWriter writer(location);
|
||||
writer(&index_type, sizeof(IndexType));
|
||||
for (auto &iter: binaryset.binary_map_) {
|
||||
auto meta = iter.first.c_str();
|
||||
size_t meta_length = iter.first.length();
|
||||
writer(&meta_length, sizeof(meta_length));
|
||||
writer((void *) meta, meta_length);
|
||||
FileIOWriter writer(location);
|
||||
writer(&index_type, sizeof(IndexType));
|
||||
for (auto &iter: binaryset.binary_map_) {
|
||||
auto meta = iter.first.c_str();
|
||||
size_t meta_length = iter.first.length();
|
||||
writer(&meta_length, sizeof(meta_length));
|
||||
writer((void *) meta, meta_length);
|
||||
|
||||
auto binary = iter.second;
|
||||
int64_t binary_length = binary->size;
|
||||
writer(&binary_length, sizeof(binary_length));
|
||||
writer((void *) binary->data.get(), binary_length);
|
||||
auto binary = iter.second;
|
||||
int64_t binary_length = binary->size;
|
||||
writer(&binary_length, sizeof(binary_length));
|
||||
writer((void *) binary->data.get(), binary_length);
|
||||
}
|
||||
} catch (knowhere::KnowhereException &e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_UNEXPECTED_ERROR;
|
||||
} catch (std::exception& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return server::KNOWHERE_ERROR;
|
||||
}
|
||||
return server::KNOWHERE_SUCCESS;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -9,6 +9,8 @@
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "utils/Error.h"
|
||||
|
||||
#include "knowhere/common/config.h"
|
||||
#include "knowhere/common/binary_set.h"
|
||||
|
||||
@ -34,23 +36,23 @@ enum class IndexType {
|
||||
|
||||
class VecIndex {
|
||||
public:
|
||||
virtual void BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt = 0,
|
||||
const float *xt = nullptr) = 0;
|
||||
virtual server::KnowhereError BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt = 0,
|
||||
const float *xt = nullptr) = 0;
|
||||
|
||||
virtual void Add(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg = Config()) = 0;
|
||||
virtual server::KnowhereError Add(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg = Config()) = 0;
|
||||
|
||||
virtual void Search(const long &nq,
|
||||
const float *xq,
|
||||
float *dist,
|
||||
long *ids,
|
||||
const Config &cfg = Config()) = 0;
|
||||
virtual server::KnowhereError Search(const long &nq,
|
||||
const float *xq,
|
||||
float *dist,
|
||||
long *ids,
|
||||
const Config &cfg = Config()) = 0;
|
||||
|
||||
virtual IndexType GetType() = 0;
|
||||
|
||||
@ -60,12 +62,12 @@ class VecIndex {
|
||||
|
||||
virtual zilliz::knowhere::BinarySet Serialize() = 0;
|
||||
|
||||
virtual void Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
|
||||
virtual server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) = 0;
|
||||
};
|
||||
|
||||
using VecIndexPtr = std::shared_ptr<VecIndex>;
|
||||
|
||||
extern void write_index(VecIndexPtr index, const std::string &location);
|
||||
extern server::KnowhereError write_index(VecIndexPtr index, const std::string &location);
|
||||
|
||||
extern VecIndexPtr read_index(const std::string &location);
|
||||
|
||||
|
||||
2
cpp/thirdparty/knowhere
vendored
2
cpp/thirdparty/knowhere
vendored
@ -1 +1 @@
|
||||
Subproject commit afaf65282737514e232bf477aacb2772a4d32d5d
|
||||
Subproject commit b0b9dd18fadbf9dc0fccaad815e14e578a92993e
|
||||
@ -163,3 +163,7 @@ TEST_P(KnowhereWrapperTest, serialize) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(linxj): add exception test
|
||||
//TEST_P(KnowhereWrapperTest, exception_test) {
|
||||
//}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user