feat(db): add more api for serializer

Former-commit-id: d15d7dfecb9964ca2e3ba4e5b469137d1cc85057
This commit is contained in:
Xu Peng 2019-04-30 15:43:08 +08:00
parent cf19e90af3
commit 052f7e2f11
5 changed files with 45 additions and 10 deletions

View File

@ -1,5 +1,6 @@
#include <easylogging++.h>
#include <faiss/AutoTune.h>
#include <wrapper/Index.h>
#include "FaissSerializer.h"
@ -9,13 +10,27 @@ namespace engine {
const std::string IndexType = "IDMap,Flat";
FaissSerializer::FaissSerializer(uint16_t dimension)
: pIndex_(faiss::index_factory(dimension, IndexType.c_str())) {
FaissSerializer::FaissSerializer(uint16_t dimension, const std::string& location)
: pIndex_(faiss::index_factory(dimension, IndexType.c_str())),
location_(location) {
}
bool FaissSerializer::AddWithIds(long n, const float *xdata, const long *xids) {
Status FaissSerializer::AddWithIds(long n, const float *xdata, const long *xids) {
pIndex_->add_with_ids(n, xdata, xids);
return true;
return Status::OK();
}
size_t FaissSerializer::Count() const {
return (size_t)(pIndex_->ntotal);
}
size_t FaissSerializer::Size() const {
return (size_t)(Count() * pIndex_->d);
}
Status FaissSerializer::Serialize() {
write_index(pIndex_.get(), location_.c_str());
return Status::OK();
}

View File

@ -15,11 +15,18 @@ namespace engine {
class FaissSerializer : public Serializer {
public:
FaissSerializer(uint16_t dimension);
virtual bool AddWithIds(long n, const float *xdata, const long *xids) override;
FaissSerializer(uint16_t dimension, const std::string& location);
virtual Status AddWithIds(long n, const float *xdata, const long *xids) override;
virtual size_t Count() const override;
virtual size_t Size() const override;
virtual Status Serialize() override;
protected:
std::shared_ptr<faiss::Index> pIndex_;
std::string location_;
};

View File

@ -5,12 +5,12 @@ namespace zilliz {
namespace vecwise {
namespace engine {
bool Serializer::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
Status Serializer::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
long n1 = (long)vectors.size();
long n2 = (long)vector_ids.size();
if (n1 != n2) {
LOG(ERROR) << "vectors size is not equal to the size of vector_ids: " << n1 << "!=" << n2;
return false;
return Status::Error("Error: AddWithIds");
}
return AddWithIds(n1, vectors.data(), vector_ids.data());
}

View File

@ -2,6 +2,8 @@
#include <vector>
#include "Status.h"
namespace zilliz {
namespace vecwise {
namespace engine {
@ -9,10 +11,16 @@ namespace engine {
class Serializer {
public:
bool AddWithIds(const std::vector<float>& vectors,
Status AddWithIds(const std::vector<float>& vectors,
const std::vector<long>& vector_ids);
virtual bool AddWithIds(long n, const float *xdata, const long *xids) = 0;
virtual Status AddWithIds(long n, const float *xdata, const long *xids) = 0;
virtual size_t Count() const = 0;
virtual size_t Size() const = 0;
virtual Status Serialize() = 0;
virtual ~Serializer() {}
};

View File

@ -21,6 +21,9 @@ public:
static Status NotFound(const std::string& msg, const std::string& msg2="") {
return Status(kNotFound, msg, msg2);
}
static Status Error(const std::string& msg, const std::string& msg2="") {
return Status(kError, msg, msg2);
}
static Status InvalidDBPath(const std::string& msg, const std::string& msg2="") {
return Status(kInvalidDBPath, msg, msg2);
@ -35,6 +38,7 @@ public:
bool ok() const { return state_ == nullptr; }
bool IsNotFound() const { return code() == kNotFound; }
bool IsError() const { return code() == kError; }
bool IsInvalidDBPath() const { return code() == kInvalidDBPath; }
bool IsGroupError() const { return code() == kGroupError; }
@ -48,6 +52,7 @@ private:
enum Code {
kOK = 0,
kNotFound,
kError,
kInvalidDBPath,
kGroupError,