From 052f7e2f11c7f39d1039a2118c7ed5f056c41f60 Mon Sep 17 00:00:00 2001 From: Xu Peng Date: Tue, 30 Apr 2019 15:43:08 +0800 Subject: [PATCH] feat(db): add more api for serializer Former-commit-id: d15d7dfecb9964ca2e3ba4e5b469137d1cc85057 --- cpp/src/db/FaissSerializer.cpp | 23 +++++++++++++++++++---- cpp/src/db/FaissSerializer.h | 11 +++++++++-- cpp/src/db/Serializer.cpp | 4 ++-- cpp/src/db/Serializer.h | 12 ++++++++++-- cpp/src/db/Status.h | 5 +++++ 5 files changed, 45 insertions(+), 10 deletions(-) diff --git a/cpp/src/db/FaissSerializer.cpp b/cpp/src/db/FaissSerializer.cpp index 6bc8b487a1..fee7750469 100644 --- a/cpp/src/db/FaissSerializer.cpp +++ b/cpp/src/db/FaissSerializer.cpp @@ -1,5 +1,6 @@ #include #include +#include #include "FaissSerializer.h" @@ -9,13 +10,27 @@ namespace engine { const std::string IndexType = "IDMap,Flat"; -FaissSerializer::FaissSerializer(uint16_t dimension) - : pIndex_(faiss::index_factory(dimension, IndexType.c_str())) { +FaissSerializer::FaissSerializer(uint16_t dimension, const std::string& location) + : pIndex_(faiss::index_factory(dimension, IndexType.c_str())), + location_(location) { } -bool FaissSerializer::AddWithIds(long n, const float *xdata, const long *xids) { +Status FaissSerializer::AddWithIds(long n, const float *xdata, const long *xids) { pIndex_->add_with_ids(n, xdata, xids); - return true; + return Status::OK(); +} + +size_t FaissSerializer::Count() const { + return (size_t)(pIndex_->ntotal); +} + +size_t FaissSerializer::Size() const { + return (size_t)(Count() * pIndex_->d); +} + +Status FaissSerializer::Serialize() { + write_index(pIndex_.get(), location_.c_str()); + return Status::OK(); } diff --git a/cpp/src/db/FaissSerializer.h b/cpp/src/db/FaissSerializer.h index fa13dad0a1..a56779996e 100644 --- a/cpp/src/db/FaissSerializer.h +++ b/cpp/src/db/FaissSerializer.h @@ -15,11 +15,18 @@ namespace engine { class FaissSerializer : public Serializer { public: - FaissSerializer(uint16_t dimension); - virtual bool AddWithIds(long n, const float *xdata, const long *xids) override; + FaissSerializer(uint16_t dimension, const std::string& location); + virtual Status AddWithIds(long n, const float *xdata, const long *xids) override; + + virtual size_t Count() const override; + + virtual size_t Size() const override; + + virtual Status Serialize() override; protected: std::shared_ptr pIndex_; + std::string location_; }; diff --git a/cpp/src/db/Serializer.cpp b/cpp/src/db/Serializer.cpp index 5a60defd50..595cd4731c 100644 --- a/cpp/src/db/Serializer.cpp +++ b/cpp/src/db/Serializer.cpp @@ -5,12 +5,12 @@ namespace zilliz { namespace vecwise { namespace engine { -bool Serializer::AddWithIds(const std::vector& vectors, const std::vector& vector_ids) { +Status Serializer::AddWithIds(const std::vector& vectors, const std::vector& vector_ids) { long n1 = (long)vectors.size(); long n2 = (long)vector_ids.size(); if (n1 != n2) { LOG(ERROR) << "vectors size is not equal to the size of vector_ids: " << n1 << "!=" << n2; - return false; + return Status::Error("Error: AddWithIds"); } return AddWithIds(n1, vectors.data(), vector_ids.data()); } diff --git a/cpp/src/db/Serializer.h b/cpp/src/db/Serializer.h index b7760fe9bc..cb2891be2e 100644 --- a/cpp/src/db/Serializer.h +++ b/cpp/src/db/Serializer.h @@ -2,6 +2,8 @@ #include +#include "Status.h" + namespace zilliz { namespace vecwise { namespace engine { @@ -9,10 +11,16 @@ namespace engine { class Serializer { public: - bool AddWithIds(const std::vector& vectors, + Status AddWithIds(const std::vector& vectors, const std::vector& vector_ids); - virtual bool AddWithIds(long n, const float *xdata, const long *xids) = 0; + virtual Status AddWithIds(long n, const float *xdata, const long *xids) = 0; + + virtual size_t Count() const = 0; + + virtual size_t Size() const = 0; + + virtual Status Serialize() = 0; virtual ~Serializer() {} }; diff --git a/cpp/src/db/Status.h b/cpp/src/db/Status.h index f45c9f6bd1..4db2b4c6e0 100644 --- a/cpp/src/db/Status.h +++ b/cpp/src/db/Status.h @@ -21,6 +21,9 @@ public: static Status NotFound(const std::string& msg, const std::string& msg2="") { return Status(kNotFound, msg, msg2); } + static Status Error(const std::string& msg, const std::string& msg2="") { + return Status(kError, msg, msg2); + } static Status InvalidDBPath(const std::string& msg, const std::string& msg2="") { return Status(kInvalidDBPath, msg, msg2); @@ -35,6 +38,7 @@ public: bool ok() const { return state_ == nullptr; } bool IsNotFound() const { return code() == kNotFound; } + bool IsError() const { return code() == kError; } bool IsInvalidDBPath() const { return code() == kInvalidDBPath; } bool IsGroupError() const { return code() == kGroupError; } @@ -48,6 +52,7 @@ private: enum Code { kOK = 0, kNotFound, + kError, kInvalidDBPath, kGroupError,