diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index 1dfafca28a..646d6e79f2 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -113,7 +113,7 @@ Status DBImpl::search(const std::string& group_id, size_t k, size_t nq, auto search_in_index = [&](meta::GroupFilesSchema& file_vec) -> void { for (auto &file : file_vec) { - FaissExecutionEngineBase index(file.dimension, file.location); + FaissExecutionEngine index(file.dimension, file.location); index.Load(); auto file_size = index.PhysicalSize()/(1024*1024); search_set_size += file_size; @@ -213,7 +213,7 @@ Status DBImpl::merge_files(const std::string& group_id, const meta::DateT& date, return status; } - FaissExecutionEngineBase index(group_file.dimension, group_file.location); + FaissExecutionEngine index(group_file.dimension, group_file.location); meta::GroupFilesSchema updated; long index_size = 0; @@ -286,7 +286,7 @@ Status DBImpl::build_index(const meta::GroupFileSchema& file) { return status; } - FaissExecutionEngineBase to_index(file.dimension, file.location); + FaissExecutionEngine to_index(file.dimension, file.location); to_index.Load(); auto index = to_index.BuildIndex(group_file.location); diff --git a/cpp/src/db/ExecutionEngine.cpp b/cpp/src/db/ExecutionEngine.cpp index ffb5450486..26e35aff1c 100644 --- a/cpp/src/db/ExecutionEngine.cpp +++ b/cpp/src/db/ExecutionEngine.cpp @@ -5,7 +5,8 @@ namespace zilliz { namespace vecwise { namespace engine { -Status ExecutionEngine::AddWithIds(const std::vector& vectors, const std::vector& vector_ids) { +template +Status ExecutionEngine::AddWithIds(const std::vector& vectors, const std::vector& vector_ids) { long n1 = (long)vectors.size(); long n2 = (long)vector_ids.size(); if (n1 != n2) { @@ -16,53 +17,42 @@ Status ExecutionEngine::AddWithIds(const std::vector& vectors, const std: } template -Status ExecutionEngineBase::AddWithIds(const std::vector& vectors, const std::vector& vector_ids) { - long n1 = (long)vectors.size(); - long n2 = (long)vector_ids.size(); - if (n1 != n2) { - LOG(ERROR) << "vectors size is not equal to the size of vector_ids: " << n1 << "!=" << n2; - return Status::Error("Error: AddWithIds"); - } - return AddWithIds(n1, vectors.data(), vector_ids.data()); -} - -template -Status ExecutionEngineBase::AddWithIds(long n, const float *xdata, const long *xids) { +Status ExecutionEngine::AddWithIds(long n, const float *xdata, const long *xids) { return static_cast(this)->AddWithIds(n, xdata, xids); } template -size_t ExecutionEngineBase::Count() const { +size_t ExecutionEngine::Count() const { return static_cast(this)->Count(); } template -size_t ExecutionEngineBase::Size() const { +size_t ExecutionEngine::Size() const { return static_cast(this)->Size(); } template -size_t ExecutionEngineBase::PhysicalSize() const { +size_t ExecutionEngine::PhysicalSize() const { return static_cast(this)->PhysicalSize(); } template -Status ExecutionEngineBase::Serialize() { +Status ExecutionEngine::Serialize() { return static_cast(this)->Serialize(); } template -Status ExecutionEngineBase::Load() { +Status ExecutionEngine::Load() { return static_cast(this)->Load(); } template -Status ExecutionEngineBase::Merge(const std::string& location) { +Status ExecutionEngine::Merge(const std::string& location) { return static_cast(this)->Merge(location); } template -Status ExecutionEngineBase::Search(long n, +Status ExecutionEngine::Search(long n, const float *data, long k, float *distances, @@ -71,12 +61,12 @@ Status ExecutionEngineBase::Search(long n, } template -Status ExecutionEngineBase::Cache() { +Status ExecutionEngine::Cache() { return static_cast(this)->Cache(); } template -std::shared_ptr ExecutionEngineBase::BuildIndex(const std::string& location) { +std::shared_ptr ExecutionEngine::BuildIndex(const std::string& location) { return static_cast(this)->BuildIndex(location); } diff --git a/cpp/src/db/ExecutionEngine.h b/cpp/src/db/ExecutionEngine.h index 30b5f6ea11..2989e5b42e 100644 --- a/cpp/src/db/ExecutionEngine.h +++ b/cpp/src/db/ExecutionEngine.h @@ -9,43 +9,8 @@ namespace zilliz { namespace vecwise { namespace engine { -class ExecutionEngine; - -class ExecutionEngine { -public: - - Status AddWithIds(const std::vector& vectors, - const std::vector& vector_ids); - - virtual Status AddWithIds(long n, const float *xdata, const long *xids) = 0; - - virtual size_t Count() const = 0; - - virtual size_t Size() const = 0; - - virtual size_t PhysicalSize() const = 0; - - virtual Status Serialize() = 0; - - virtual Status Load() = 0; - - virtual Status Merge(const std::string& location) = 0; - - virtual Status Search(long n, - const float *data, - long k, - float *distances, - long *labels) const = 0; - - virtual std::shared_ptr BuildIndex(const std::string&) = 0; - - virtual Status Cache() = 0; - - virtual ~ExecutionEngine() {} -}; - template -class ExecutionEngineBase { +class ExecutionEngine { public: Status AddWithIds(const std::vector& vectors, diff --git a/cpp/src/db/FaissExecutionEngine.cpp b/cpp/src/db/FaissExecutionEngine.cpp index ab836d6835..6f275f58f6 100644 --- a/cpp/src/db/FaissExecutionEngine.cpp +++ b/cpp/src/db/FaissExecutionEngine.cpp @@ -16,6 +16,7 @@ namespace engine { const std::string RawIndexType = "IDMap,Flat"; const std::string BuildIndexType = "IDMap,Flat"; + FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension, const std::string& location) : pIndex_(faiss::index_factory(dimension, RawIndexType.c_str())), location_(location) { @@ -74,7 +75,7 @@ Status FaissExecutionEngine::Merge(const std::string& location) { return Status::OK(); } -std::shared_ptr FaissExecutionEngine::BuildIndex(const std::string& location) { +std::shared_ptr FaissExecutionEngine::BuildIndex(const std::string& location) { auto opd = std::make_shared(); opd->d = pIndex_->d; opd->index_type = BuildIndexType; @@ -86,7 +87,7 @@ std::shared_ptr FaissExecutionEngine::BuildIndex(const std::str dynamic_cast(from_index->index)->xb.data(), from_index->id_map.data()); - std::shared_ptr new_ee(new FaissExecutionEngine(index->data(), location)); + std::shared_ptr new_ee(new FaissExecutionEngine(index->data(), location)); new_ee->Serialize(); return new_ee; } @@ -109,99 +110,6 @@ Status FaissExecutionEngine::Cache() { } -FaissExecutionEngineBase::FaissExecutionEngineBase(uint16_t dimension, const std::string& location) - : pIndex_(faiss::index_factory(dimension, RawIndexType.c_str())), - location_(location) { -} - -FaissExecutionEngineBase::FaissExecutionEngineBase(std::shared_ptr index, const std::string& location) - : pIndex_(index), - location_(location) { -} - -Status FaissExecutionEngineBase::AddWithIds(long n, const float *xdata, const long *xids) { - pIndex_->add_with_ids(n, xdata, xids); - return Status::OK(); -} - -size_t FaissExecutionEngineBase::Count() const { - return (size_t)(pIndex_->ntotal); -} - -size_t FaissExecutionEngineBase::Size() const { - return (size_t)(Count() * pIndex_->d); -} - -size_t FaissExecutionEngineBase::PhysicalSize() const { - return (size_t)(Size()*sizeof(float)); -} - -Status FaissExecutionEngineBase::Serialize() { - write_index(pIndex_.get(), location_.c_str()); - return Status::OK(); -} - -Status FaissExecutionEngineBase::Load() { - auto index = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(location_); - if (!index) { - index = read_index(location_); - Cache(); - LOG(DEBUG) << "Disk io from: " << location_; - } - - pIndex_ = index->data(); - return Status::OK(); -} - -Status FaissExecutionEngineBase::Merge(const std::string& location) { - if (location == location_) { - return Status::Error("Cannot Merge Self"); - } - auto to_merge = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(location); - if (!to_merge) { - to_merge = read_index(location); - } - auto file_index = dynamic_cast(to_merge->data().get()); - pIndex_->add_with_ids(file_index->ntotal, dynamic_cast(file_index->index)->xb.data(), - file_index->id_map.data()); - return Status::OK(); -} - -std::shared_ptr FaissExecutionEngineBase::BuildIndex(const std::string& location) { - auto opd = std::make_shared(); - opd->d = pIndex_->d; - opd->index_type = BuildIndexType; - IndexBuilderPtr pBuilder = GetIndexBuilder(opd); - - auto from_index = dynamic_cast(pIndex_.get()); - - auto index = pBuilder->build_all(from_index->ntotal, - dynamic_cast(from_index->index)->xb.data(), - from_index->id_map.data()); - - std::shared_ptr new_ee(new FaissExecutionEngineBase(index->data(), location)); - new_ee->Serialize(); - return new_ee; -} - -Status FaissExecutionEngineBase::Search(long n, - const float *data, - long k, - float *distances, - long *labels) const { - - pIndex_->search(n, data, k, distances, labels); - return Status::OK(); -} - -Status FaissExecutionEngineBase::Cache() { - zilliz::vecwise::cache::CpuCacheMgr::GetInstance( - )->InsertItem(location_, std::make_shared(pIndex_)); - - return Status::OK(); -} - - } // namespace engine } // namespace vecwise } // namespace zilliz diff --git a/cpp/src/db/FaissExecutionEngine.h b/cpp/src/db/FaissExecutionEngine.h index 925b2685c9..e2d007ffc7 100644 --- a/cpp/src/db/FaissExecutionEngine.h +++ b/cpp/src/db/FaissExecutionEngine.h @@ -13,44 +13,12 @@ namespace zilliz { namespace vecwise { namespace engine { -class FaissExecutionEngine : public ExecutionEngine { + +class FaissExecutionEngine : public ExecutionEngine { public: FaissExecutionEngine(uint16_t dimension, const std::string& location); FaissExecutionEngine(std::shared_ptr index, const std::string& location); - virtual Status AddWithIds(long n, const float *xdata, const long *xids) override; - - virtual size_t Count() const override; - - virtual size_t Size() const override; - - virtual size_t PhysicalSize() const override; - - virtual Status Merge(const std::string& location) override; - - virtual Status Serialize() override; - virtual Status Load() override; - - virtual Status Cache() override; - - virtual Status Search(long n, - const float *data, - long k, - float *distances, - long *labels) const override; - - virtual std::shared_ptr BuildIndex(const std::string&) override; - -protected: - std::shared_ptr pIndex_; - std::string location_; -}; - -class FaissExecutionEngineBase : public ExecutionEngineBase { -public: - FaissExecutionEngineBase(uint16_t dimension, const std::string& location); - FaissExecutionEngineBase(std::shared_ptr index, const std::string& location); - Status AddWithIds(const std::vector& vectors, const std::vector& vector_ids); @@ -74,7 +42,7 @@ public: float *distances, long *labels) const; - std::shared_ptr BuildIndex(const std::string&); + std::shared_ptr BuildIndex(const std::string&); Status Cache(); protected: diff --git a/cpp/src/db/MemManager.cpp b/cpp/src/db/MemManager.cpp index fa2858e309..904c5db150 100644 --- a/cpp/src/db/MemManager.cpp +++ b/cpp/src/db/MemManager.cpp @@ -18,7 +18,7 @@ MemVectors::MemVectors(const std::shared_ptr& meta_ptr, options_(options), schema_(schema), _pIdGenerator(new SimpleIDGenerator()), - pEE_(new FaissExecutionEngineBase(schema_.dimension, schema_.location)) { + pEE_(new FaissExecutionEngine(schema_.dimension, schema_.location)) { } void MemVectors::add(size_t n_, const float* vectors_, IDNumbers& vector_ids_) { diff --git a/cpp/src/db/MemManager.h b/cpp/src/db/MemManager.h index 374f75cd59..c1d8736407 100644 --- a/cpp/src/db/MemManager.h +++ b/cpp/src/db/MemManager.h @@ -19,7 +19,7 @@ namespace meta { class Meta; } -class FaissExecutionEngineBase; +class FaissExecutionEngine; class MemVectors { public: @@ -47,7 +47,7 @@ private: Options options_; meta::GroupFileSchema schema_; IDGenerator* _pIdGenerator; - std::shared_ptr pEE_; + std::shared_ptr pEE_; }; // MemVectors