diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index 7efbd54f0f..d1cc1ae4ff 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -134,6 +134,24 @@ IndexType VecIndexImpl::GetType() { return type; } +VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) { + //if (auto new_type = GetGpuIndexType(type)) { + // auto device_index = index_->CopyToGpu(device_id); + // return std::make_shared(device_index, new_type); + //} + //return nullptr; + + // TODO(linxj): update type + auto gpu_index = zilliz::knowhere::CopyCpuToGpu(index_, device_id, cfg); + return std::make_shared(gpu_index, type); +} + +// TODO(linxj): rename copytocpu => copygputocpu +VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) { + auto cpu_index = zilliz::knowhere::CopyGpuToCpu(index_, cfg); + return std::make_shared(cpu_index, type); +} + float *BFIndex::GetRawVectors() { auto raw_index = std::dynamic_pointer_cast(index_); if (raw_index) { return raw_index->GetRawVectors(); } diff --git a/cpp/src/wrapper/knowhere/vec_impl.h b/cpp/src/wrapper/knowhere/vec_impl.h index c4a0e2ac61..5e46c16f70 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.h +++ b/cpp/src/wrapper/knowhere/vec_impl.h @@ -25,6 +25,8 @@ class VecIndexImpl : public VecIndex { const Config &cfg, const long &nt, const float *xt) override; + VecIndexPtr CopyToGpu(const int64_t &device_id, const Config &cfg) override; + VecIndexPtr CopyToCpu(const Config &cfg) override; IndexType GetType() override; int64_t Dimension() override; int64_t Count() override; diff --git a/cpp/src/wrapper/knowhere/vec_index.h b/cpp/src/wrapper/knowhere/vec_index.h index 80c8771dda..088228386c 100644 --- a/cpp/src/wrapper/knowhere/vec_index.h +++ b/cpp/src/wrapper/knowhere/vec_index.h @@ -35,6 +35,9 @@ enum class IndexType { NSG_MIX, }; +class VecIndex; +using VecIndexPtr = std::shared_ptr; + class VecIndex { public: virtual server::KnowhereError BuildAll(const long &nb, @@ -55,6 +58,11 @@ class VecIndex { long *ids, const Config &cfg = Config()) = 0; + virtual VecIndexPtr CopyToGpu(const int64_t& device_id, + const Config &cfg = Config()) = 0; + + virtual VecIndexPtr CopyToCpu(const Config &cfg = Config()) = 0; + virtual IndexType GetType() = 0; virtual int64_t Dimension() = 0; @@ -66,8 +74,6 @@ class VecIndex { virtual server::KnowhereError Load(const zilliz::knowhere::BinarySet &index_binary) = 0; }; -using VecIndexPtr = std::shared_ptr; - extern server::KnowhereError write_index(VecIndexPtr index, const std::string &location); extern VecIndexPtr read_index(const std::string &location);