From 2ac87c1e473052cff8beb8799b23aa0a8a7898c1 Mon Sep 17 00:00:00 2001 From: "xj.lin" Date: Sun, 5 May 2019 20:50:08 +0800 Subject: [PATCH] 1. support IDMap 2. fix some bug 3. background job from IDMap -> IVF Former-commit-id: ba8f24f09c5481103ad3f4c1c91d4deb70f26dad --- cpp/src/db/FaissExecutionEngine.cpp | 2 +- cpp/src/wrapper/IndexBuilder.cpp | 38 +++++++++++++++--- cpp/src/wrapper/IndexBuilder.h | 36 +++++++++++------ cpp/src/wrapper/Operand.cpp | 43 +++++++++++++++++++++ cpp/src/wrapper/Operand.h | 12 ++++-- cpp/unittest/faiss_wrapper/wrapper_test.cpp | 14 +++---- 6 files changed, 115 insertions(+), 30 deletions(-) diff --git a/cpp/src/db/FaissExecutionEngine.cpp b/cpp/src/db/FaissExecutionEngine.cpp index 6c86d2fbbd..06b7127217 100644 --- a/cpp/src/db/FaissExecutionEngine.cpp +++ b/cpp/src/db/FaissExecutionEngine.cpp @@ -19,7 +19,7 @@ namespace vecwise { namespace engine { const std::string RawIndexType = "IDMap,Flat"; -const std::string BuildIndexType = "IDMap,Flat"; +const std::string BuildIndexType = "IVF"; // IDMap / IVF FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension, const std::string& location) diff --git a/cpp/src/wrapper/IndexBuilder.cpp b/cpp/src/wrapper/IndexBuilder.cpp index 6d98106bba..e9552b984d 100644 --- a/cpp/src/wrapper/IndexBuilder.cpp +++ b/cpp/src/wrapper/IndexBuilder.cpp @@ -9,6 +9,7 @@ #include #include "faiss/gpu/GpuIndexIVFFlat.h" #include "faiss/gpu/GpuAutoTune.h" +#include "faiss/IndexFlat.h" #include "IndexBuilder.h" @@ -20,6 +21,7 @@ namespace engine { using std::vector; static std::mutex gpu_resource; +static std::mutex cpu_resource; IndexBuilder::IndexBuilder(const Operand_ptr &opd) { opd_ = opd; @@ -27,14 +29,14 @@ IndexBuilder::IndexBuilder(const Operand_ptr &opd) { // Default: build use gpu Index_ptr IndexBuilder::build_all(const long &nb, - const float* xb, - const long* ids, + const float *xb, + const long *ids, const long &nt, - const float* xt) { + const float *xt) { std::shared_ptr host_index = nullptr; { // TODO: list support index-type. - faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->index_type.c_str()); + faiss::Index *ori_index = faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str()); std::lock_guard lk(gpu_resource); faiss::gpu::StandardGpuResources res; @@ -43,7 +45,7 @@ Index_ptr IndexBuilder::build_all(const long &nb, nt == 0 || xt == nullptr ? device_index->train(nb, xb) : device_index->train(nt, xt); } - device_index->add_with_ids(nb, xb, ids); + device_index->add_with_ids(nb, xb, ids); // TODO: support with add_with_IDMAP host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index)); @@ -60,8 +62,32 @@ Index_ptr IndexBuilder::build_all(const long &nb, const vector &xb, return build_all(nb, xb.data(), ids.data(), nt, xt.data()); } -// Be Factory pattern later +BgCpuBuilder::BgCpuBuilder(const zilliz::vecwise::engine::Operand_ptr &opd) : IndexBuilder(opd) {}; + +Index_ptr BgCpuBuilder::build_all(const long &nb, const float *xb, const long *ids, const long &nt, const float *xt) { + std::shared_ptr index = nullptr; + index.reset(faiss::index_factory(opd_->d, opd_->get_index_type(nb).c_str())); + + { + std::lock_guard lk(cpu_resource); + if (!index->is_trained) { + nt == 0 || xt == nullptr ? index->train(nb, xb) + : index->train(nt, xt); + } + index->add_with_ids(nb, xb, ids); + } + + return std::make_shared(index); +} + +// TODO: Be Factory pattern later IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd) { + if (opd->index_type == "IDMap") { + // TODO: fix hardcode + IndexBuilderPtr index = nullptr; + return std::make_shared(opd); + } + return std::make_shared(opd); } diff --git a/cpp/src/wrapper/IndexBuilder.h b/cpp/src/wrapper/IndexBuilder.h index e8acc89eac..e4819326b3 100644 --- a/cpp/src/wrapper/IndexBuilder.h +++ b/cpp/src/wrapper/IndexBuilder.h @@ -11,25 +11,26 @@ #include "Operand.h" #include "Index.h" + namespace zilliz { namespace vecwise { namespace engine { class IndexBuilder { -public: + public: explicit IndexBuilder(const Operand_ptr &opd); - Index_ptr build_all(const long &nb, - const float* xb, - const long* ids, - const long &nt = 0, - const float* xt = nullptr); + virtual Index_ptr build_all(const long &nb, + const float *xb, + const long *ids, + const long &nt = 0, + const float *xt = nullptr); - Index_ptr build_all(const long &nb, - const std::vector &xb, - const std::vector &ids, - const long &nt = 0, - const std::vector &xt = std::vector()); + virtual Index_ptr build_all(const long &nb, + const std::vector &xb, + const std::vector &ids, + const long &nt = 0, + const std::vector &xt = std::vector()); void train(const long &nt, const std::vector &xt); @@ -41,10 +42,21 @@ public: void set_build_option(const Operand_ptr &opd); -private: + protected: Operand_ptr opd_ = nullptr; }; +class BgCpuBuilder : public IndexBuilder { + public: + BgCpuBuilder(const Operand_ptr &opd); + + virtual Index_ptr build_all(const long &nb, + const float *xb, + const long *ids, + const long &nt = 0, + const float *xt = nullptr) override; +}; + using IndexBuilderPtr = std::shared_ptr; extern IndexBuilderPtr GetIndexBuilder(const Operand_ptr &opd); diff --git a/cpp/src/wrapper/Operand.cpp b/cpp/src/wrapper/Operand.cpp index 131f821b3c..e3c4155086 100644 --- a/cpp/src/wrapper/Operand.cpp +++ b/cpp/src/wrapper/Operand.cpp @@ -6,10 +6,53 @@ #include "Operand.h" + namespace zilliz { namespace vecwise { namespace engine { +using std::string; + +enum IndexType { + Invalid_Option = 0, + IVF = 1, + IDMAP = 2 +}; + +IndexType resolveIndexType(const string &index_type) { + if (index_type == "IVF") { return IndexType::IVF; } + if (index_type == "IDMap") { return IndexType::IDMAP; } + return IndexType::Invalid_Option; +} + +// nb at least 100 +string Operand::get_index_type(const int &nb) { + if (!index_str.empty()) { return index_str; } + + // TODO: support OPQ or ... + if (!preproc.empty()) { index_str += (preproc + ","); } + + switch (resolveIndexType(index_type)) { + case Invalid_Option: { + // TODO: add exception + break; + } + case IVF: { + index_str += (ncent != 0 ? index_type + std::to_string(ncent) : + index_type + std::to_string(int(nb / 1000000.0 * 16384))); + break; + } + case IDMAP: { + index_str += index_type; + break; + } + } + + // TODO: support PQ or ... + if (!postproc.empty()) { index_str += ("," + postproc); } + return index_str; +} + std::ostream &operator<<(std::ostream &os, const Operand &obj) { os << obj.d << " " << obj.index_type << " " diff --git a/cpp/src/wrapper/Operand.h b/cpp/src/wrapper/Operand.h index 047ca917bf..f20cb30894 100644 --- a/cpp/src/wrapper/Operand.h +++ b/cpp/src/wrapper/Operand.h @@ -11,6 +11,7 @@ #include #include + namespace zilliz { namespace vecwise { namespace engine { @@ -21,11 +22,14 @@ struct Operand { friend std::istream &operator>>(std::istream &is, Operand &obj); int d; - std::string index_type = "IVF13864,Flat"; - std::string metric_type = "L2"; //> L2 / Inner Product + std::string index_type = "IVF"; + std::string metric_type = "L2"; //> L2 / IP(Inner Product) std::string preproc; - std::string postproc; - int ncent; + std::string postproc = "Flat"; + std::string index_str; + int ncent = 0; + + std::string get_index_type(const int &nb); }; using Operand_ptr = std::shared_ptr; diff --git a/cpp/unittest/faiss_wrapper/wrapper_test.cpp b/cpp/unittest/faiss_wrapper/wrapper_test.cpp index 87a6729054..480dcad0d3 100644 --- a/cpp/unittest/faiss_wrapper/wrapper_test.cpp +++ b/cpp/unittest/faiss_wrapper/wrapper_test.cpp @@ -18,17 +18,17 @@ TEST(operand_test, Wrapper_Test) { using std::endl; auto opd = std::make_shared(); - opd->index_type = "IDMap,Flat"; - opd->preproc = "opq"; - opd->postproc = "pq"; + opd->index_type = "IVF"; + opd->preproc = "OPQ"; + opd->postproc = "PQ"; opd->metric_type = "L2"; - opd->ncent = 256; opd->d = 64; auto opd_str = operand_to_str(opd); auto new_opd = str_to_operand(opd_str); - assert(new_opd->index_type == opd->index_type); + // TODO: fix all place where using opd to build index. + assert(new_opd->get_index_type(10000) == opd->get_index_type(10000)); } TEST(build_test, Wrapper_Test) { @@ -56,7 +56,7 @@ TEST(build_test, Wrapper_Test) { //train the index auto opd = std::make_shared(); - opd->index_type = "IVF16,Flat"; + opd->index_type = "IVF"; opd->d = d; opd->ncent = ncentroids; IndexBuilderPtr index_builder_1 = GetIndexBuilder(opd); @@ -120,7 +120,7 @@ TEST(gpu_build_test, Wrapper_Test) { for (int i = 0; i < nb; ++i) { ids[i] = i; } auto opd = std::make_shared(); - opd->index_type = "IVF256,Flat"; + opd->index_type = "IVF"; opd->d = d; opd->ncent = 256;