diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index d1cc1ae4ff..f6bdd82618 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -8,6 +8,7 @@ #include "knowhere/index/vector_index/idmap.h" #include "knowhere/index/vector_index/gpu_ivf.h" #include "knowhere/common/exception.h" +#include "knowhere/index/vector_index/cloner.h" #include "vec_impl.h" #include "data_transfer.h" @@ -152,6 +153,22 @@ VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) { return std::make_shared(cpu_index, type); } +VecIndexPtr VecIndexImpl::Clone() { + auto clone_index = std::make_shared(index_->Clone(), type); + clone_index->dim = dim; + return clone_index; +} + +int64_t VecIndexImpl::GetDeviceId() { + if (auto device_idx = std::dynamic_pointer_cast(index_)){ + return device_idx->GetGpuDevice(); + } + else { + return -1; // -1 == cpu + } + return 0; +} + float *BFIndex::GetRawVectors() { auto raw_index = std::dynamic_pointer_cast(index_); if (raw_index) { return raw_index->GetRawVectors(); } diff --git a/cpp/src/wrapper/knowhere/vec_impl.h b/cpp/src/wrapper/knowhere/vec_impl.h index 5e46c16f70..f03f299f78 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.h +++ b/cpp/src/wrapper/knowhere/vec_impl.h @@ -33,6 +33,8 @@ class VecIndexImpl : public VecIndex { server::KnowhereError Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override; zilliz::knowhere::BinarySet Serialize() override; server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) override; + VecIndexPtr Clone() override; + int64_t GetDeviceId() override; server::KnowhereError Search(const long &nq, const float *xq, float *dist, long *ids, const Config &cfg) override; protected: diff --git a/cpp/src/wrapper/knowhere/vec_index.h b/cpp/src/wrapper/knowhere/vec_index.h index 088228386c..19f0c6d360 100644 --- a/cpp/src/wrapper/knowhere/vec_index.h +++ b/cpp/src/wrapper/knowhere/vec_index.h @@ -63,6 +63,10 @@ class VecIndex { virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0; + virtual VecIndexPtr Clone() = 0; + + virtual int64_t GetDeviceId() = 0; + virtual IndexType GetType() = 0; virtual int64_t Dimension() = 0;