diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp index b519c28e04..5786440726 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp @@ -41,11 +41,9 @@ GPUIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { idx_config.device = static_cast(gpu_id_); int32_t nlist = config[IndexParams::nlist]; faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); - auto device_index = - new faiss::gpu::GpuIndexIVFFlat(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config); - device_index->train(rows, (float*)p_data); - - index_.reset(device_index); + index_ = std::make_shared(gpu_res->faiss_res.get(), dim, nlist, metric_type, + idx_config); + index_->train(rows, (float*)p_data); res_ = gpu_res; } else { KNOWHERE_THROW_MSG("Build IVF can't get gpu resource"); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp index e5449ae50d..2b251239de 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp @@ -38,10 +38,9 @@ GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { int32_t m = config[IndexParams::m]; int32_t nbits = config[IndexParams::nbits]; faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); - auto device_index = - new faiss::gpu::GpuIndexIVFPQ(gpu_res->faiss_res.get(), dim, nlist, m, nbits, metric_type, idx_config); - device_index->train(rows, (float*)p_data); - index_.reset(device_index); + index_ = std::make_shared(gpu_res->faiss_res.get(), dim, nlist, m, nbits, + metric_type, idx_config); + index_->train(rows, (float*)p_data); res_ = gpu_res; } else { KNOWHERE_THROW_MSG("Build IVFPQ can't get gpu resource"); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp index deea119e76..b90bc87747 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp @@ -36,10 +36,9 @@ GPUIVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { idx_config.device = static_cast(gpu_id_); int32_t nlist = config[IndexParams::nlist]; faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); - auto device_index = new faiss::gpu::GpuIndexIVFScalarQuantizer( + index_ = std::make_shared( gpu_res->faiss_res.get(), dim, nlist, faiss::QuantizerType::QT_8bit, metric_type, true, idx_config); - device_index->train(rows, (float*)p_data); - index_.reset(device_index); + index_->train(rows, (float*)p_data); res_ = gpu_res; } else { KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource"); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp index b1e147f2ec..db63dbb157 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp @@ -12,8 +12,7 @@ #include #include -#include -#include +#include #include #include #include @@ -34,28 +33,22 @@ IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) { GETTENSOR(dataset_ptr) gpu_id_ = config[knowhere::meta::DEVICEID]; - std::stringstream index_type; - index_type << "IVF" << config[IndexParams::nlist] << "," - << "SQ8Hybrid"; - auto build_index = - faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get())); - auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); if (gpu_res != nullptr) { ResScope rs(gpu_res, gpu_id_, true); - auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index); - device_index->train(rows, (float*)p_data); - - index_.reset(device_index); + faiss::gpu::GpuIndexIVFSQHybridConfig idx_config; + idx_config.device = static_cast(gpu_id_); + int32_t nlist = config[IndexParams::nlist]; + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + index_ = std::make_shared( + gpu_res->faiss_res.get(), dim, nlist, faiss::QuantizerType::QT_8bit, metric_type, true, idx_config); + index_->train(rows, reinterpret_cast(p_data)); res_ = gpu_res; gpu_mode_ = 2; index_mode_ = IndexMode::MODE_GPU; } else { - delete build_index; KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource"); } - - delete build_index; } VecIndexPtr