milvus/cpp/src/wrapper/IndexBuilder.cpp
xj.lin 502f1a956c MS-27 support gpu config
Former-commit-id: 08749b66413000571d733a28303eed3944220a9b
2019-05-31 18:25:06 +08:00

142 lines
4.1 KiB
C++

////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include "mutex"
#include <memory>
#ifdef GPU_VERSION
#include <faiss/gpu/StandardGpuResources.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/GpuAutoTune.h>
#endif
#include <faiss/IndexFlat.h>
#include <easylogging++.h>
#include "server/ServerConfig.h"
#include "IndexBuilder.h"
namespace zilliz {
namespace vecwise {
namespace engine {
class GpuResources {
public:
static GpuResources &GetInstance() {
static GpuResources instance;
return instance;
}
void SelectGpu() {
using namespace zilliz::vecwise::server;
ServerConfig &config = ServerConfig::GetInstance();
ConfigNode server_config = config.GetConfig(CONFIG_SERVER);
gpu_num = server_config.GetInt32Value("gpu_index", 0);
}
int32_t GetGpu() {
return gpu_num;
}
private:
GpuResources() : gpu_num(0) { SelectGpu(); }
private:
int32_t gpu_num;
};
using std::vector;

namespace {
// Held by IndexBuilder::build_all (GPU build) around the whole
// copy-to-device / train / add / copy-back sequence.
std::mutex gpu_resource;
// Held by BgCpuBuilder::build_all around train / add.
std::mutex cpu_resource;
}  // namespace
// Store the operand (dimension and index-type description) used by build_all.
IndexBuilder::IndexBuilder(const Operand_ptr &opd)
    : opd_(opd) {}
/**
 * Build an index described by opd_ over nb base vectors.
 *
 * With GPU_VERSION defined, the index is created on the host, copied to the
 * GPU selected by GpuResources, trained (if needed) and filled there while
 * gpu_resource is held, then copied back to the host. Without GPU_VERSION,
 * the index is created, trained and filled on the CPU.
 *
 * All intermediate indexes are owned by smart pointers, so nothing leaks if
 * faiss throws part-way through.
 *
 * @param nb  number of base vectors in xb
 * @param xb  base vectors (presumably nb * opd_->d floats — confirm with callers)
 * @param ids one id per base vector, forwarded to add_with_ids
 * @param nt  number of training vectors; 0 means "train on xb"
 * @param xt  training vectors; nullptr means "train on xb"
 * @return the built host-side index wrapped in an Index
 */
Index_ptr IndexBuilder::build_all(const long &nb,
                                  const float *xb,
                                  const long *ids,
                                  const long &nt,
                                  const float *xt) {
    std::shared_ptr<faiss::Index> host_index = nullptr;
#ifdef GPU_VERSION
    {
        // TODO: list support index-type.
        std::unique_ptr<faiss::Index> ori_index(
            faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()));
        std::lock_guard<std::mutex> lk(gpu_resource);
        faiss::gpu::StandardGpuResources res;
        // Declared after `res` so it is destroyed before the GPU resources it uses.
        std::unique_ptr<faiss::Index> device_index(faiss::gpu::index_cpu_to_gpu(
            &res, GpuResources::GetInstance().GetGpu(), ori_index.get()));
        if (!device_index->is_trained) {
            if (nt == 0 || xt == nullptr) {
                device_index->train(nb, xb);
            } else {
                device_index->train(nt, xt);
            }
        }
        device_index->add_with_ids(nb, xb, ids); // TODO: support with add_with_IDMAP
        host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index.get()));
    }
#else
    {
        // Take ownership right away so the index is released if train/add throws.
        host_index.reset(faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()));
        if (!host_index->is_trained) {
            if (nt == 0 || xt == nullptr) {
                host_index->train(nb, xb);
            } else {
                host_index->train(nt, xt);
            }
        }
        host_index->add_with_ids(nb, xb, ids);
    }
#endif
    return std::make_shared<Index>(host_index);
}
/**
 * Convenience overload taking std::vector buffers; forwards to the pointer
 * overload of build_all.
 *
 * An empty xt is forwarded as nullptr so the pointer overload falls back to
 * training on xb. Forwarding xt.data() for an empty vector could hand
 * `train` a dangling or zero-length buffer whenever nt != 0.
 */
Index_ptr IndexBuilder::build_all(const long &nb, const vector<float> &xb,
                                  const vector<long> &ids,
                                  const long &nt, const vector<float> &xt) {
    const float *train_data = xt.empty() ? nullptr : xt.data();
    return build_all(nb, xb.data(), ids.data(), nt, train_data);
}
// All state (the operand) is held by the IndexBuilder base.
BgCpuBuilder::BgCpuBuilder(const zilliz::vecwise::engine::Operand_ptr &opd)
    : IndexBuilder(opd) {}
/**
 * CPU-only build: create the index described by opd_, then train it (if
 * needed) and add the base vectors while holding cpu_resource, so concurrent
 * calls perform their train/add steps one at a time.
 *
 * @param nb  number of base vectors in xb
 * @param xb  base vectors
 * @param ids one id per base vector, forwarded to add_with_ids
 * @param nt  number of training vectors; 0 means "train on xb"
 * @param xt  training vectors; nullptr means "train on xb"
 * @return the built index wrapped in an Index
 */
Index_ptr BgCpuBuilder::build_all(const long &nb, const float *xb, const long *ids, const long &nt, const float *xt) {
    std::shared_ptr<faiss::Index> cpu_index(
        faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()));
    {
        std::lock_guard<std::mutex> guard(cpu_resource);
        if (!cpu_index->is_trained) {
            const bool train_on_base = (nt == 0 || xt == nullptr);
            if (train_on_base) {
                cpu_index->train(nb, xb);
            } else {
                cpu_index->train(nt, xt);
            }
        }
        cpu_index->add_with_ids(nb, xb, ids);
    }
    return std::make_shared<Index>(cpu_index);
}
// TODO: Be Factory pattern later
/**
 * Pick the builder for an operand.
 *
 * Operands whose index_type is "IDMap" get a BgCpuBuilder; every other
 * index_type gets the default IndexBuilder.
 */
IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd) {
    if (opd->index_type == "IDMap") {
        // TODO: fix hardcode
        return std::make_shared<BgCpuBuilder>(opd);
    }
    return std::make_shared<IndexBuilder>(opd);
}
}
}
}