diff --git a/cpp/src/db/MemManager.cpp b/cpp/src/db/MemManager.cpp index 31d1ed0db3..3f74301f47 100644 --- a/cpp/src/db/MemManager.cpp +++ b/cpp/src/db/MemManager.cpp @@ -1,5 +1,6 @@ -#include -#include +/* #include */ +/* #include */ +#include #include #include #include @@ -19,20 +20,19 @@ MemVectors::MemVectors(const std::string& group_id, _file_location(file_location), _pIdGenerator(new SimpleIDGenerator()), _dimension(dimension), - _pInnerIndex(new faiss::IndexFlat(_dimension)), - _pIdMapIndex(new faiss::IndexIDMap(_pInnerIndex)) { + pIndex_(faiss::index_factory(_dimension, "IDMap,Flat")) { } void MemVectors::add(size_t n_, const float* vectors_, IDNumbers& vector_ids_) { _pIdGenerator->getNextIDNumbers(n_, vector_ids_); - _pIdMapIndex->add_with_ids(n_, vectors_, &vector_ids_[0]); + pIndex_->add_with_ids(n_, vectors_, &vector_ids_[0]); for(auto i=0 ; intotal; + return pIndex_->ntotal; } size_t MemVectors::approximate_size() const { @@ -42,10 +42,10 @@ size_t MemVectors::approximate_size() const { Status MemVectors::serialize(std::string& group_id) { /* std::stringstream ss; */ /* ss << "/tmp/test/" << _pIdGenerator->getNextIDNumber(); */ - /* faiss::write_index(_pIdMapIndex, ss.str().c_str()); */ - /* std::cout << _pIdMapIndex->ntotal << std::endl; */ + /* faiss::write_index(pIndex_, ss.str().c_str()); */ + /* std::cout << pIndex_->ntotal << std::endl; */ /* std::cout << _file_location << std::endl; */ - faiss::write_index(_pIdMapIndex, _file_location.c_str()); + faiss::write_index(pIndex_, _file_location.c_str()); group_id = group_id_; return Status::OK(); } @@ -55,13 +55,9 @@ MemVectors::~MemVectors() { delete _pIdGenerator; _pIdGenerator = nullptr; } - if (_pIdMapIndex != nullptr) { - delete _pIdMapIndex; - _pIdMapIndex = nullptr; - } - if (_pInnerIndex != nullptr) { - delete _pInnerIndex; - _pInnerIndex = nullptr; + if (pIndex_ != nullptr) { + delete pIndex_; + pIndex_ = nullptr; } } diff --git a/cpp/src/db/MemManager.h b/cpp/src/db/MemManager.h index 86b3973d62..48aacc4fb6 100644 --- a/cpp/src/db/MemManager.h +++ b/cpp/src/db/MemManager.h @@ -10,7 +10,6 @@ #include "Status.h" namespace faiss { - class IndexIDMap; class Index; } @@ -50,8 +49,7 @@ private: const std::string _file_location; IDGenerator* _pIdGenerator; size_t _dimension; - faiss::Index* _pInnerIndex; - faiss::IndexIDMap* _pIdMapIndex; + faiss::Index* pIndex_; }; // MemVectors