//////////////////////////////////////////////////////////////////////////////// // Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved // Unauthorized copying of this file, via any medium is strictly prohibited. // Proprietary and confidential. //////////////////////////////////////////////////////////////////////////////// #include "mutex" #ifdef GPU_VERSION #include #include #include #endif #include #include #include "server/ServerConfig.h" #include "IndexBuilder.h" namespace zilliz { namespace vecwise { namespace engine { class GpuResources { public: static GpuResources &GetInstance() { static GpuResources instance; return instance; } void SelectGpu() { using namespace zilliz::vecwise::server; ServerConfig &config = ServerConfig::GetInstance(); ConfigNode server_config = config.GetConfig(CONFIG_SERVER); gpu_num = server_config.GetInt32Value("gpu_index", 0); } int32_t GetGpu() { return gpu_num; } private: GpuResources() : gpu_num(0) { SelectGpu(); } private: int32_t gpu_num; }; using std::vector; static std::mutex gpu_resource; static std::mutex cpu_resource; IndexBuilder::IndexBuilder(const Operand_ptr &opd) { opd_ = opd; } // Default: build use gpu Index_ptr IndexBuilder::build_all(const long &nb, const float *xb, const long *ids, const long &nt, const float *xt) { std::shared_ptr host_index = nullptr; #ifdef GPU_VERSION { // TODO: list support index-type. faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()); std::lock_guard lk(gpu_resource); faiss::gpu::StandardGpuResources res; auto device_index = faiss::gpu::index_cpu_to_gpu(&res, GpuResources::GetInstance().GetGpu(), ori_index); if (!device_index->is_trained) { nt == 0 || xt == nullptr ? device_index->train(nb, xb) : device_index->train(nt, xt); } device_index->add_with_ids(nb, xb, ids); // TODO: support with add_with_IDMAP host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index)); delete device_index; delete ori_index; } #else { faiss::Index *index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()); if (!index->is_trained) { nt == 0 || xt == nullptr ? index->train(nb, xb) : index->train(nt, xt); } index->add_with_ids(nb, xb, ids); host_index.reset(index); } #endif return std::make_shared(host_index); } Index_ptr IndexBuilder::build_all(const long &nb, const vector &xb, const vector &ids, const long &nt, const vector &xt) { return build_all(nb, xb.data(), ids.data(), nt, xt.data()); } BgCpuBuilder::BgCpuBuilder(const zilliz::vecwise::engine::Operand_ptr &opd) : IndexBuilder(opd) {}; Index_ptr BgCpuBuilder::build_all(const long &nb, const float *xb, const long *ids, const long &nt, const float *xt) { std::shared_ptr index = nullptr; index.reset(faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str())); { std::lock_guard lk(cpu_resource); if (!index->is_trained) { nt == 0 || xt == nullptr ? index->train(nb, xb) : index->train(nt, xt); } index->add_with_ids(nb, xb, ids); } return std::make_shared(index); } // TODO: Be Factory pattern later IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd) { if (opd->index_type == "IDMap") { // TODO: fix hardcode IndexBuilderPtr index = nullptr; return std::make_shared(opd); } return std::make_shared(opd); } } } }