mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
feat(db): add more api for serializer
Former-commit-id: d15d7dfecb9964ca2e3ba4e5b469137d1cc85057
This commit is contained in:
parent
cf19e90af3
commit
052f7e2f11
@ -1,5 +1,6 @@
|
||||
#include <easylogging++.h>
|
||||
#include <faiss/AutoTune.h>
|
||||
#include <wrapper/Index.h>
|
||||
|
||||
#include "FaissSerializer.h"
|
||||
|
||||
@ -9,13 +10,27 @@ namespace engine {
|
||||
|
||||
const std::string IndexType = "IDMap,Flat";
|
||||
|
||||
FaissSerializer::FaissSerializer(uint16_t dimension)
|
||||
: pIndex_(faiss::index_factory(dimension, IndexType.c_str())) {
|
||||
FaissSerializer::FaissSerializer(uint16_t dimension, const std::string& location)
|
||||
: pIndex_(faiss::index_factory(dimension, IndexType.c_str())),
|
||||
location_(location) {
|
||||
}
|
||||
|
||||
bool FaissSerializer::AddWithIds(long n, const float *xdata, const long *xids) {
|
||||
Status FaissSerializer::AddWithIds(long n, const float *xdata, const long *xids) {
|
||||
pIndex_->add_with_ids(n, xdata, xids);
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
size_t FaissSerializer::Count() const {
|
||||
return (size_t)(pIndex_->ntotal);
|
||||
}
|
||||
|
||||
size_t FaissSerializer::Size() const {
|
||||
return (size_t)(Count() * pIndex_->d);
|
||||
}
|
||||
|
||||
Status FaissSerializer::Serialize() {
|
||||
write_index(pIndex_.get(), location_.c_str());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -15,11 +15,18 @@ namespace engine {
|
||||
|
||||
class FaissSerializer : public Serializer {
|
||||
public:
|
||||
FaissSerializer(uint16_t dimension);
|
||||
virtual bool AddWithIds(long n, const float *xdata, const long *xids) override;
|
||||
FaissSerializer(uint16_t dimension, const std::string& location);
|
||||
virtual Status AddWithIds(long n, const float *xdata, const long *xids) override;
|
||||
|
||||
virtual size_t Count() const override;
|
||||
|
||||
virtual size_t Size() const override;
|
||||
|
||||
virtual Status Serialize() override;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<faiss::Index> pIndex_;
|
||||
std::string location_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
@ -5,12 +5,12 @@ namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace engine {
|
||||
|
||||
bool Serializer::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
|
||||
Status Serializer::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
|
||||
long n1 = (long)vectors.size();
|
||||
long n2 = (long)vector_ids.size();
|
||||
if (n1 != n2) {
|
||||
LOG(ERROR) << "vectors size is not equal to the size of vector_ids: " << n1 << "!=" << n2;
|
||||
return false;
|
||||
return Status::Error("Error: AddWithIds");
|
||||
}
|
||||
return AddWithIds(n1, vectors.data(), vector_ids.data());
|
||||
}
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Status.h"
|
||||
|
||||
namespace zilliz {
|
||||
namespace vecwise {
|
||||
namespace engine {
|
||||
@ -9,10 +11,16 @@ namespace engine {
|
||||
class Serializer {
|
||||
public:
|
||||
|
||||
bool AddWithIds(const std::vector<float>& vectors,
|
||||
Status AddWithIds(const std::vector<float>& vectors,
|
||||
const std::vector<long>& vector_ids);
|
||||
|
||||
virtual bool AddWithIds(long n, const float *xdata, const long *xids) = 0;
|
||||
virtual Status AddWithIds(long n, const float *xdata, const long *xids) = 0;
|
||||
|
||||
virtual size_t Count() const = 0;
|
||||
|
||||
virtual size_t Size() const = 0;
|
||||
|
||||
virtual Status Serialize() = 0;
|
||||
|
||||
virtual ~Serializer() {}
|
||||
};
|
||||
|
||||
@ -21,6 +21,9 @@ public:
|
||||
static Status NotFound(const std::string& msg, const std::string& msg2="") {
|
||||
return Status(kNotFound, msg, msg2);
|
||||
}
|
||||
static Status Error(const std::string& msg, const std::string& msg2="") {
|
||||
return Status(kError, msg, msg2);
|
||||
}
|
||||
|
||||
static Status InvalidDBPath(const std::string& msg, const std::string& msg2="") {
|
||||
return Status(kInvalidDBPath, msg, msg2);
|
||||
@ -35,6 +38,7 @@ public:
|
||||
bool ok() const { return state_ == nullptr; }
|
||||
|
||||
bool IsNotFound() const { return code() == kNotFound; }
|
||||
bool IsError() const { return code() == kError; }
|
||||
|
||||
bool IsInvalidDBPath() const { return code() == kInvalidDBPath; }
|
||||
bool IsGroupError() const { return code() == kGroupError; }
|
||||
@ -48,6 +52,7 @@ private:
|
||||
enum Code {
|
||||
kOK = 0,
|
||||
kNotFound,
|
||||
kError,
|
||||
|
||||
kInvalidDBPath,
|
||||
kGroupError,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user